from typing import Optional, Tuple
import torch
from torch import Tensor, nn

from .blocks import BasicBlock2D as BasicBlock
from .blocks import BasicBlock2DNoFinalRelu
from .blocks import *

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception type specified (W0702),
#     Too many locals (R0914)
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914


class DTNet(nn.Module):
    """DeepThinking Network 2D model class (3x3 convolutions).

    The input is projected to ``width`` channels, a weight-shared
    ``recur_block`` is applied ``iters_to_do`` times, and after every
    iteration a small conv head maps the recurrent state to a 2-channel map
    with the same spatial size as the input.

    Args:
        block: block class used to build the recurrent module. It is called as
            ``block(in_planes, planes, stride, group_norm=..., bias=...)`` and
            must expose an ``expansion`` attribute.
        num_blocks: sequence with the number of ``block`` instances in each
            stage of the recurrent module (may be empty).
        width: number of channels of the recurrent state.
        in_channels: number of channels of the input ``x``.
        recall: if True, ``x`` is concatenated to the state before every
            recurrent step and a 3x3 conv maps it back to ``width`` channels.
        group_norm: forwarded to every ``block``.
        bias: whether all convolutions (and blocks) use a bias term.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True,
                 group_norm=False, bias=False, **kwargs):
        super().__init__()

        self.bias = bias
        self.recall = recall
        # Running channel count, read and advanced by _make_layer.
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)
        # Created unconditionally (even when unused) to keep the original
        # parameter-initialisation RNG order independent of ``recall``.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``; the rest use stride 1. After each
        block ``self.width`` becomes ``planes * block.expansion``. A
        non-positive ``num_blocks`` yields an empty ``nn.Sequential``.
        """
        # Guard: previously num_blocks=0 still produced one block.
        strides = [stride] + [1] * (num_blocks - 1) if num_blocks > 0 else []
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd,
                                group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch, indexed as ``(batch, in_channels, H, W)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional recurrent state to start from; defaults
                to ``self.projection(x)``.
            return_all_outputs: in training mode, also return the
                per-iteration outputs.
            **kwargs: ignored.

        Returns:
            Training mode: ``(out, interim_thought)``, or
            ``(all_outputs, out, interim_thought)`` when ``return_all_outputs``
            is True, where ``out`` is the head output of the last iteration.
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, H, W)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1`` (there is no
                last-iteration output to return).
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate on x's device directly instead of building on CPU and copying.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

    def forward_L1_interim(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,
                           only_final_and_next=False, **kwargs):
        """Like ``forward`` but also accumulates an L1 penalty on the recurrent states.

        When ``only_final_and_next`` is False, the penalty is the sum over
        iterations of ``mean(|state_t - state_{t-1}|)`` (consecutive states,
        starting from the initial state). When True, it is the sum of
        ``mean(|state_k - state_final|)`` over every earlier state, including
        the initial one.

        Returns:
            Training mode: ``(out, interim_thought, noise_loss)``, or
            ``(all_outputs, out, interim_thought, noise_loss)`` when
            ``return_all_outputs`` is True. ``noise_loss`` is the int ``0`` if
            no term was added. Eval mode: ``all_outputs`` only (no loss).

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        previous_interim_thought = interim_thought
        noise_loss = 0
        # States for the only_final_and_next penalty, starting state first.
        noise_interims = [interim_thought]
        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)

            if only_final_and_next:
                noise_interims.append(interim_thought)
            else:
                # L1 distance between consecutive recurrent states.
                noise_loss += torch.mean(torch.abs(interim_thought - previous_interim_thought))
                previous_interim_thought = interim_thought

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if only_final_and_next:
            final_interim = noise_interims[-1]
            for earlier in noise_interims[:-1]:
                noise_loss += torch.mean(torch.abs(earlier - final_interim))

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought, noise_loss
            return out, interim_thought, noise_loss

        return all_outputs


def dt_net_2d(width, **kwargs):
    """Build a 3x3 DTNet with one stage of two BasicBlocks and no input recall.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock, num_blocks=[2], width=width,
                 in_channels=in_channels, recall=False)


def dt_net_recall_2d(width, **kwargs):
    """Build a 3x3 DTNet with recall and one stage of two BasicBlocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock, num_blocks=[2], width=width,
                 in_channels=in_channels, recall=True)


def dt_net_recall_2d_with_bias(width, **kwargs):
    """Build a 3x3 recall DTNet (two BasicBlocks) whose convolutions use bias.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock, num_blocks=[2], width=width,
                 in_channels=in_channels, recall=True, bias=True)


def dt_net_recall_2d_1block(width, **kwargs):
    """Build a 3x3 recall DTNet with a single BasicBlock.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock, num_blocks=[1], width=width,
                 in_channels=in_channels, recall=True)

def dt_net_recall_2d_1block_noprop(width, **kwargs):
    """Build a 3x3 recall DTNet with a single BasicBlock2DNoProp block.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock2DNoProp, num_blocks=[1], width=width,
                 in_channels=in_channels, recall=True)

def dt_net_recall_2d_onlyrecall(width, **kwargs):
    """Build a 3x3 recall DTNet whose recurrent module is only the recall conv.

    ``num_blocks`` is empty, so the block class is never instantiated.
    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock2DNoProp, num_blocks=[], width=width,
                 in_channels=in_channels, recall=True)

def dt_net_recall_2d_2block1x1(width, **kwargs):
    """Build a 3x3 recall DTNet with one stage of two BasicBlock2D1x1 blocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock2D1x1, num_blocks=[2], width=width,
                 in_channels=in_channels, recall=True)




def dt_net_gn_2d(width, **kwargs):
    """Build a 3x3 DTNet without recall, two BasicBlocks, group_norm enabled.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock, num_blocks=[2], width=width,
                 in_channels=in_channels, recall=False, group_norm=True)


def dt_net_recall_gn_2d(width, **kwargs):
    """Build a 3x3 recall DTNet with two BasicBlocks and group_norm enabled.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet(block=BasicBlock, num_blocks=[2], width=width,
                 in_channels=in_channels, recall=True, group_norm=True)

class DTNet5x5(nn.Module):
    """DeepThinking Network 2D model class with 5x5 convolutions.

    The input is projected to ``width`` channels, a weight-shared
    ``recur_block`` is applied ``iters_to_do`` times, and after every
    iteration a conv head maps the recurrent state to a 2-channel map. Every
    convolution created here uses ``kernel_size=5, padding=2``, so the
    spatial size is preserved.

    Args:
        block: block class called as
            ``block(in_planes, planes, stride, group_norm=..., bias=...)``;
            must expose an ``expansion`` attribute.
        num_blocks: sequence with the number of ``block`` instances per stage.
        width: number of channels of the recurrent state.
        in_channels: number of channels of the input ``x``.
        recall: if True, ``x`` is concatenated to the state before every
            recurrent step and a 5x5 conv maps it back to ``width`` channels.
        group_norm: forwarded to every ``block``.
        bias: whether all convolutions (and blocks) use a bias term.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True,
                 group_norm=False, bias=False, **kwargs):
        super().__init__()

        self.bias = bias
        self.recall = recall
        # Running channel count, read and advanced by _make_layer.
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=5,
                              stride=1, padding=2, bias=bias)
        # Created unconditionally (even when unused) to keep the original
        # parameter-initialisation RNG order independent of ``recall``.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=5,
                                stride=1, padding=2, bias=bias)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=5,
                               stride=1, padding=2, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=5,
                               stride=1, padding=2, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=5,
                               stride=1, padding=2, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``. After each block ``self.width``
        becomes ``planes * block.expansion``. A non-positive ``num_blocks``
        yields an empty ``nn.Sequential``.
        """
        # Guard: previously num_blocks=0 still produced one block.
        strides = [stride] + [1] * (num_blocks - 1) if num_blocks > 0 else []
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd,
                                group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch, indexed as ``(batch, in_channels, H, W)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional starting state; defaults to
                ``self.projection(x)``.
            return_all_outputs: in training mode, also return the
                per-iteration outputs.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(all_outputs, out, interim_thought)``; eval mode: ``all_outputs``
            of shape ``(batch, iters_to_do, 2, H, W)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate on x's device directly instead of building on CPU and copying.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def dt_net_recall_5x5_bblock2(width, **kwargs):
    """Build a 5x5 recall DTNet with one stage of two (3x3) BasicBlocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet5x5(block=BasicBlock, num_blocks=[2], width=width,
                    in_channels=in_channels, recall=True)

def dt_net_recall_5x5(width, **kwargs):
    """Build a 5x5 recall DTNet with one stage of two BasicBlock2D5x5 blocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet5x5(block=BasicBlock2D5x5, num_blocks=[2], width=width,
                    in_channels=in_channels, recall=True)


class DTNet7x7(nn.Module):
    """DeepThinking Network 2D model class with 7x7 convolutions.

    The input is projected to ``width`` channels, a weight-shared
    ``recur_block`` is applied ``iters_to_do`` times, and after every
    iteration a conv head maps the recurrent state to a 2-channel map. Every
    convolution created here uses ``kernel_size=7, padding=3``, so the
    spatial size is preserved.

    Args:
        block: block class called as
            ``block(in_planes, planes, stride, group_norm=..., bias=...)``;
            must expose an ``expansion`` attribute.
        num_blocks: sequence with the number of ``block`` instances per stage.
        width: number of channels of the recurrent state.
        in_channels: number of channels of the input ``x``.
        recall: if True, ``x`` is concatenated to the state before every
            recurrent step and a 7x7 conv maps it back to ``width`` channels.
        group_norm: forwarded to every ``block``.
        bias: whether all convolutions (and blocks) use a bias term.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True,
                 group_norm=False, bias=False, **kwargs):
        super().__init__()

        self.bias = bias
        self.recall = recall
        # Running channel count, read and advanced by _make_layer.
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=7,
                              stride=1, padding=3, bias=bias)
        # Created unconditionally (even when unused) to keep the original
        # parameter-initialisation RNG order independent of ``recall``.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=7,
                                stride=1, padding=3, bias=bias)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=7,
                               stride=1, padding=3, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=7,
                               stride=1, padding=3, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=7,
                               stride=1, padding=3, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``. After each block ``self.width``
        becomes ``planes * block.expansion``. A non-positive ``num_blocks``
        yields an empty ``nn.Sequential``.
        """
        # Guard: previously num_blocks=0 still produced one block.
        strides = [stride] + [1] * (num_blocks - 1) if num_blocks > 0 else []
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd,
                                group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch, indexed as ``(batch, in_channels, H, W)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional starting state; defaults to
                ``self.projection(x)``.
            return_all_outputs: in training mode, also return the
                per-iteration outputs.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(all_outputs, out, interim_thought)``; eval mode: ``all_outputs``
            of shape ``(batch, iters_to_do, 2, H, W)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate on x's device directly instead of building on CPU and copying.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

def dt_net_recall_7x7(width, **kwargs):
    """Build a 7x7 recall DTNet with one stage of two BasicBlock2D7x7 blocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet7x7(block=BasicBlock2D7x7, num_blocks=[2], width=width,
                    in_channels=in_channels, recall=True)


class DTNetOnlyRecallProp(nn.Module):
    """DeepThinking Network 2D model class with 1x1 projection and head.

    The projection and the head convolutions are 1x1 (no spatial mixing); the
    recall convolution is 3x3. Spatial mixing inside the recurrent module
    otherwise depends on ``block``. All convolutions created here have no bias.

    Args:
        block: block class called as
            ``block(in_planes, planes, stride, group_norm=...)``; must expose
            an ``expansion`` attribute.
        num_blocks: sequence with the number of ``block`` instances per stage.
        width: number of channels of the recurrent state.
        in_channels: number of channels of the input ``x``.
        recall: if True, ``x`` is concatenated to the state before every
            recurrent step and the 3x3 recall conv maps it back to ``width``.
        group_norm: forwarded to every ``block``.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True,
                 group_norm=False, **kwargs):
        super().__init__()

        self.recall = recall
        # Running channel count, read and advanced by _make_layer.
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=1,
                              stride=1, padding=0, bias=False)
        # Created unconditionally (even when unused) to keep the original
        # parameter-initialisation RNG order independent of ``recall``.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=1,
                               stride=1, padding=0, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``. After each block ``self.width``
        becomes ``planes * block.expansion``. A non-positive ``num_blocks``
        yields an empty ``nn.Sequential``.
        """
        # Guard: previously num_blocks=0 still produced one block.
        strides = [stride] + [1] * (num_blocks - 1) if num_blocks > 0 else []
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch, indexed as ``(batch, in_channels, H, W)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional starting state; defaults to
                ``self.projection(x)``.
            return_all_outputs: in training mode, also return the
                per-iteration outputs.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(all_outputs, out, interim_thought)``; eval mode: ``all_outputs``
            of shape ``(batch, iters_to_do, 2, H, W)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate on x's device directly instead of building on CPU and copying.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

    def forward_L1_interim(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,
                           only_final_and_next=False, **kwargs):
        """Like ``forward`` but also accumulates an L1 penalty on the recurrent states.

        When ``only_final_and_next`` is False, the penalty is the sum over
        iterations of ``mean(|state_t - state_{t-1}|)`` (consecutive states,
        starting from the initial state). When True, it is the sum of
        ``mean(|state_k - state_final|)`` over every earlier state, including
        the initial one.

        Returns:
            Training mode: ``(out, interim_thought, noise_loss)``, or
            ``(all_outputs, out, interim_thought, noise_loss)`` when
            ``return_all_outputs`` is True. Eval mode: ``all_outputs`` only.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        previous_interim_thought = interim_thought
        noise_loss = 0
        # States for the only_final_and_next penalty, starting state first.
        noise_interims = [interim_thought]
        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)

            if only_final_and_next:
                noise_interims.append(interim_thought)
            else:
                # L1 distance between consecutive recurrent states.
                noise_loss += torch.mean(torch.abs(interim_thought - previous_interim_thought))
                previous_interim_thought = interim_thought

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if only_final_and_next:
            final_interim = noise_interims[-1]
            for earlier in noise_interims[:-1]:
                noise_loss += torch.mean(torch.abs(earlier - final_interim))

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought, noise_loss
            return out, interim_thought, noise_loss

        return all_outputs

def dt_net_recall_2d_2b_onlyrecall_prop(width, **kwargs):
    """Build a DTNetOnlyRecallProp with recall and two BasicBlock2DNoProp blocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetOnlyRecallProp(block=BasicBlock2DNoProp, num_blocks=[2], width=width,
                               in_channels=in_channels, recall=True)


def dt_net_recall_2d_1b_onlyrecall_prop(width, **kwargs):
    """Build a DTNetOnlyRecallProp with recall and one BasicBlock2DNoProp block.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetOnlyRecallProp(block=BasicBlock2DNoProp, num_blocks=[1], width=width,
                               in_channels=in_channels, recall=True)

def dt_net_recall_2d_0b_onlyrecall_prop(width, **kwargs):
    """Build a DTNetOnlyRecallProp whose recurrent module is only the recall conv.

    ``num_blocks`` is empty, so the block class is never instantiated.
    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetOnlyRecallProp(block=BasicBlock2DNoProp, num_blocks=[], width=width,
                               in_channels=in_channels, recall=True)


def dt_net_recall_2d_onlyrecurrent3x3(width, **kwargs):
    """Build a DTNetOnlyRecallProp with recall and two BasicBlocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetOnlyRecallProp(block=BasicBlock, num_blocks=[2], width=width,
                               in_channels=in_channels, recall=True)


class DTNet_Conv1(nn.Module):
    """DeepThinking Network 2D model class whose own convolutions are all 1x1.

    The projection, recall and head convolutions are 1x1 without bias, so any
    spatial mixing comes from ``block``.

    Args:
        block: block class called as
            ``block(in_planes, planes, stride, group_norm=...)``; must expose
            an ``expansion`` attribute.
        num_blocks: sequence with the number of ``block`` instances per stage.
        width: number of channels of the recurrent state.
        in_channels: number of channels of the input ``x``.
        recall: if True, ``x`` is concatenated to the state before every
            recurrent step and a 1x1 conv maps it back to ``width`` channels.
        group_norm: forwarded to every ``block``.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True,
                 group_norm=False, **kwargs):
        super().__init__()

        self.recall = recall
        # Running channel count, read and advanced by _make_layer.
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=1,
                              stride=1, padding=0, bias=False)
        # Created unconditionally (even when unused) to keep the original
        # parameter-initialisation RNG order independent of ``recall``.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=1,
                                stride=1, padding=0, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=1,
                               stride=1, padding=0, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``. After each block ``self.width``
        becomes ``planes * block.expansion``. A non-positive ``num_blocks``
        yields an empty ``nn.Sequential``.
        """
        # Guard: previously num_blocks=0 still produced one block.
        strides = [stride] + [1] * (num_blocks - 1) if num_blocks > 0 else []
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch, indexed as ``(batch, in_channels, H, W)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional starting state; defaults to
                ``self.projection(x)``.
            return_all_outputs: in training mode, also return the
                per-iteration outputs.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(all_outputs, out, interim_thought)``; eval mode: ``all_outputs``
            of shape ``(batch, iters_to_do, 2, H, W)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate on x's device directly instead of building on CPU and copying.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

    def forward_L1_interim(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,
                           only_final_and_next=False, custom_beta=0, **kwargs):
        """Like ``forward`` but also computes a noise-robustness L1 penalty.

        At every iteration, ``recur_block`` is additionally applied to the
        (post-recall) state perturbed by multiplicative Gaussian noise
        ``randn * state * custom_beta``. For iterations ``i >= 3`` the
        penalty accumulates ``mean(|noisy_next - clean_next.detach()|)``, so
        gradients flow only through the noisy branch.

        ``only_final_and_next`` is accepted for signature compatibility with
        the other ``forward_L1_interim`` methods but is not used here.

        Returns:
            Training mode: ``(out, interim_thought, noise_loss)``, or
            ``(all_outputs, out, interim_thought, noise_loss)`` when
            ``return_all_outputs`` is True. ``noise_loss`` is the int ``0`` if
            no term was added (e.g. ``iters_to_do <= 3``). Eval mode:
            ``all_outputs`` only.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        noise_loss = 0
        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            # The noisy pass runs before the clean one; keep this order so the
            # RNG stream and any stateful layers see the same call sequence.
            noise = torch.randn_like(interim_thought) * interim_thought * custom_beta
            interim_thought_noise = self.recur_block(interim_thought + noise)
            interim_thought = self.recur_block(interim_thought)

            # The penalty skips the first three iterations (i = 0, 1, 2).
            if i > 2:
                noise_loss += torch.mean(torch.abs(interim_thought_noise - interim_thought.detach()))

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought, noise_loss
            return out, interim_thought, noise_loss

        return all_outputs


def dt_net_recall_2d_1conv_2bblock(width, **kwargs):
    """Build a 1x1-conv DTNet_Conv1 with recall and two BasicBlocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet_Conv1(block=BasicBlock, num_blocks=[2], width=width,
                       in_channels=in_channels, recall=True)

def dt_net_recall_2d_1conv_1bblock(width, **kwargs):
    """Build a 1x1-conv DTNet_Conv1 with recall and a single BasicBlock.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet_Conv1(block=BasicBlock, num_blocks=[1], width=width,
                       in_channels=in_channels, recall=True)


# BasicBlock2DSimple
# BasicBlock2DSimple2

def dt_net_recall_2d_1conv_1prop1block(width, **kwargs):
    """Build a DTNet_Conv1 with recall and a single BasicBlock2DPropBegin block.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet_Conv1(block=BasicBlock2DPropBegin, num_blocks=[1], width=width,
                       in_channels=in_channels, recall=True)

def dt_net_recall_2d_1conv_2prop1block(width, **kwargs):
    """Build a DTNet_Conv1 with recall and two BasicBlock2DPropBegin blocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet_Conv1(block=BasicBlock2DPropBegin, num_blocks=[2], width=width,
                       in_channels=in_channels, recall=True)


def dt_net_recall_2d_1conv_1prop1blocklast(width, **kwargs):
    """Build a DTNet_Conv1 with recall and a single BasicBlock2DPropEnd block.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet_Conv1(block=BasicBlock2DPropEnd, num_blocks=[1], width=width,
                       in_channels=in_channels, recall=True)

def dt_net_recall_2d_1x1(width, **kwargs):
    """Build a DTNet_Conv1 with recall and two BasicBlock2D1x1 blocks.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet_Conv1(block=BasicBlock2D1x1, num_blocks=[2], width=width,
                       in_channels=in_channels, recall=True)


class DTNetReduceMaxPool(nn.Module):
    """DeepThinking Network 2D model class with a spatially-pooled head.

    The head applies a 3x3 conv and ReLU, then pools to 1x1 with adaptive
    max pooling (or adaptive average pooling when ``use_AvgPool`` is True).
    Two more 3x3 convs follow and yield ``output_size`` values per sample.
    All convolutions created here have no bias.

    Args:
        block: block class called as
            ``block(in_planes, planes, stride, group_norm=...)``; must expose
            an ``expansion`` attribute.
        num_blocks: sequence with the number of ``block`` instances per stage.
        width: number of channels of the recurrent state.
        output_size: number of output values per sample.
        in_channels: number of channels of the input ``x``.
        recall: if True, ``x`` is concatenated to the state before every
            recurrent step and a 3x3 conv maps it back to ``width`` channels.
        group_norm: forwarded to every ``block``.
        use_AvgPool: use ``AdaptiveAvgPool2d`` instead of ``AdaptiveMaxPool2d``.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True,
                 group_norm=False, use_AvgPool=False, **kwargs):
        super().__init__()

        self.recall = recall
        # Running channel count, read and advanced by _make_layer.
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=False)
        # Created unconditionally (even when unused) to keep the original
        # parameter-initialisation RNG order independent of ``recall``.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        pool_cls = nn.AdaptiveAvgPool2d if use_AvgPool else nn.AdaptiveMaxPool2d
        head_pool = pool_cls(output_size=1)

        head_conv1 = nn.Conv2d(width, width, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(32, output_size, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_pool,
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.output_size = output_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``. After each block ``self.width``
        becomes ``planes * block.expansion``. A non-positive ``num_blocks``
        yields an empty ``nn.Sequential``.
        """
        # Guard: previously num_blocks=0 still produced one block.
        strides = [stride] + [1] * (num_blocks - 1) if num_blocks > 0 else []
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch, indexed as ``(batch, in_channels, H, W)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional starting state; defaults to
                ``self.projection(x)``.
            return_all_outputs: in training mode, also return the
                per-iteration outputs.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(all_outputs, out, interim_thought)`` with ``out`` of shape
            ``(batch, output_size)``; eval mode: ``all_outputs`` of shape
            ``(batch, iters_to_do, output_size)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate on x's device directly instead of building on CPU and copying.
        all_outputs = torch.zeros((x.size(0), iters_to_do, self.output_size), device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            # Head output is (batch, output_size, 1, 1) after pooling; flatten it.
            out = self.head(interim_thought).view(x.size(0), self.output_size)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

def dt_net_recall_2d_out10(width, **kwargs):
    """Build a max-pooled recall DTNet (two BasicBlocks) with 10 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPool(block=BasicBlock, num_blocks=[2], width=width, output_size=10,
                              in_channels=in_channels, recall=True)

def dt_net_recall_2d_out4(width, **kwargs):
    """Build a max-pooled recall DTNet (two BasicBlocks) with 4 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPool(block=BasicBlock, num_blocks=[2], width=width, output_size=4,
                              in_channels=in_channels, recall=True)

def dt_net_2d_out4(width, **kwargs):
    """Build a max-pooled DTNet without recall (two BasicBlocks) with 4 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPool(block=BasicBlock, num_blocks=[2], width=width, output_size=4,
                              in_channels=in_channels, recall=False)

def dt_net_2d_out4_avg(width, **kwargs):
    """Build an average-pooled DTNet without recall (two BasicBlocks), 4 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPool(block=BasicBlock, num_blocks=[2], width=width, output_size=4,
                              in_channels=in_channels, recall=False, use_AvgPool=True)

def dt_net_recall_2d_out4_1block(width, **kwargs):
    """Build a max-pooled recall DTNet with a single BasicBlock and 4 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPool(block=BasicBlock, num_blocks=[1], width=width, output_size=4,
                              in_channels=in_channels, recall=True)


class DTNetReduceMaxPool_1Conv(nn.Module):
    """DeepThinking Network 2D model class: 1x1 convs and a max-pooled head.

    The projection, recall and head convolutions are all 1x1 without bias.
    The head applies conv + ReLU, adaptive max pooling to 1x1, then two more
    convs yielding ``output_size`` values per sample.

    Args:
        block: block class called as
            ``block(in_planes, planes, stride, group_norm=...)``; must expose
            an ``expansion`` attribute.
        num_blocks: sequence with the number of ``block`` instances per stage.
        width: number of channels of the recurrent state.
        output_size: number of output values per sample.
        in_channels: number of channels of the input ``x``.
        recall: if True, ``x`` is concatenated to the state before every
            recurrent step and a 1x1 conv maps it back to ``width`` channels.
        group_norm: forwarded to every ``block``.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True,
                 group_norm=False, **kwargs):
        super().__init__()

        self.recall = recall
        # Running channel count, read and advanced by _make_layer.
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=1,
                              stride=1, padding=0, bias=False)
        # Created unconditionally (even when unused) to keep the original
        # parameter-initialisation RNG order independent of ``recall``.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=1,
                                stride=1, padding=0, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_pool = nn.AdaptiveMaxPool2d(output_size=1)

        head_conv1 = nn.Conv2d(width, width, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv2 = nn.Conv2d(width, 32, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv3 = nn.Conv2d(32, output_size, kernel_size=1,
                               stride=1, padding=0, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_pool,
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.output_size = output_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``. After each block ``self.width``
        becomes ``planes * block.expansion``. A non-positive ``num_blocks``
        yields an empty ``nn.Sequential``.
        """
        # Guard: previously num_blocks=0 still produced one block.
        strides = [stride] + [1] * (num_blocks - 1) if num_blocks > 0 else []
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch, indexed as ``(batch, in_channels, H, W)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional starting state; defaults to
                ``self.projection(x)``.
            return_all_outputs: in training mode, also return the
                per-iteration outputs.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(all_outputs, out, interim_thought)`` with ``out`` of shape
            ``(batch, output_size)``; eval mode: ``all_outputs`` of shape
            ``(batch, iters_to_do, output_size)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate on x's device directly instead of building on CPU and copying.
        all_outputs = torch.zeros((x.size(0), iters_to_do, self.output_size), device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            # Head output is (batch, output_size, 1, 1) after pooling; flatten it.
            out = self.head(interim_thought).view(x.size(0), self.output_size)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

def dt_net_recall_2d_out4_1conv(width, **kwargs):
    """Build a 1x1-conv max-pooled recall DTNet (two BasicBlocks), 4 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPool_1Conv(block=BasicBlock, num_blocks=[2], width=width,
                                    output_size=4, in_channels=in_channels, recall=True)

def dt_net_recall_2d_out4_1conv_1block(width, **kwargs):
    """Build a 1x1-conv max-pooled recall DTNet with one BasicBlock, 4 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPool_1Conv(block=BasicBlock, num_blocks=[1], width=width,
                                    output_size=4, in_channels=in_channels, recall=True)



class DTNetReduceMaxPoolEnd(nn.Module):
    """DeepThinking Network 2D model class that pools at the very end of the head.

    The head applies three 3x3 convs (ReLU between them) producing
    ``output_size`` channels. Adaptive max pooling (or adaptive average
    pooling when ``use_AvgPool`` is True) to 1x1 is applied last. All
    convolutions created here have no bias.

    Args:
        block: block class called as
            ``block(in_planes, planes, stride, group_norm=...)``; must expose
            an ``expansion`` attribute.
        num_blocks: sequence with the number of ``block`` instances per stage.
        width: number of channels of the recurrent state.
        output_size: number of output values per sample.
        in_channels: number of channels of the input ``x``.
        recall: if True, ``x`` is concatenated to the state before every
            recurrent step and a 3x3 conv maps it back to ``width`` channels.
        group_norm: forwarded to every ``block``.
        use_AvgPool: use ``AdaptiveAvgPool2d`` instead of ``AdaptiveMaxPool2d``.
        **kwargs: accepted for config compatibility and ignored.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True,
                 group_norm=False, use_AvgPool=False, **kwargs):
        super().__init__()

        self.recall = recall
        # Running channel count, read and advanced by _make_layer.
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=False)
        # Created unconditionally (even when unused) to keep the original
        # parameter-initialisation RNG order independent of ``recall``.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        pool_cls = nn.AdaptiveAvgPool2d if use_AvgPool else nn.AdaptiveMaxPool2d
        head_pool = pool_cls(output_size=1)

        head_conv1 = nn.Conv2d(width, width, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(32, output_size, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3,
                                  head_pool)

        self.output_size = output_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``. After each block ``self.width``
        becomes ``planes * block.expansion``. A non-positive ``num_blocks``
        yields an empty ``nn.Sequential``.
        """
        # Guard: previously num_blocks=0 still produced one block.
        strides = [stride] + [1] * (num_blocks - 1) if num_blocks > 0 else []
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch, indexed as ``(batch, in_channels, H, W)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional starting state; defaults to
                ``self.projection(x)``.
            return_all_outputs: in training mode, also return the
                per-iteration outputs.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(all_outputs, out, interim_thought)`` with ``out`` of shape
            ``(batch, output_size)``; eval mode: ``all_outputs`` of shape
            ``(batch, iters_to_do, output_size)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate on x's device directly instead of building on CPU and copying.
        all_outputs = torch.zeros((x.size(0), iters_to_do, self.output_size), device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            # Head output is (batch, output_size, 1, 1) after pooling; flatten it.
            out = self.head(interim_thought).view(x.size(0), self.output_size)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

def dt_net_recall_2d_out4_end(width, **kwargs):
    """Build an end-pooled (max) recall DTNet with two BasicBlocks and 4 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPoolEnd(block=BasicBlock, num_blocks=[2], width=width,
                                 output_size=4, in_channels=in_channels, recall=True)

def dt_net_recall_2d_out4_avg_end(width, **kwargs):
    """Build an end-pooled (average) recall DTNet with two BasicBlocks, 4 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPoolEnd(block=BasicBlock, num_blocks=[2], width=width,
                                 output_size=4, in_channels=in_channels, recall=True,
                                 use_AvgPool=True)


def dt_net_recall_2d_out10_end(width, **kwargs):
    """Build an end-pooled (max) recall DTNet with two BasicBlocks and 10 outputs.

    ``kwargs`` must contain ``in_channels``; any other keys are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetReduceMaxPoolEnd(block=BasicBlock, num_blocks=[2], width=width,
                                 output_size=10, in_channels=in_channels, recall=True)




from torch.autograd import Variable
import numpy as np

class Conv2dGRUCell(nn.Module):
    """Convolutional GRU cell for 2D feature maps.

    The input-to-hidden and hidden-to-hidden transforms are ``nn.Conv2d`` layers
    padded by ``kernel_size // 2`` per dimension. The gates follow the GRU equations::

        r  = sigmoid(Wxr * x + Whr * h)
        z  = sigmoid(Wxz * x + Whz * h)
        n  = tanh(Wxn * x + r * (Whn * h))
        h' = z * h + (1 - z) * n

    Args:
        input_size: channels of the input tensor.
        hidden_size: channels of the hidden state.
        kernel_size: an int, or a pair ``(kh, kw)`` given as a tuple or list.
        bias: whether both convolutions have bias terms.

    Raises:
        ValueError: if ``kernel_size`` is neither an int nor a length-2 tuple/list.
    """

    def __init__(self, input_size, hidden_size, kernel_size, bias=True):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        if isinstance(kernel_size, (tuple, list)) and len(kernel_size) == 2:
            self.kernel_size = tuple(kernel_size)
            self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        # bool is an int subclass; keep rejecting it as the original type() check did.
        elif isinstance(kernel_size, int) and not isinstance(kernel_size, bool):
            self.kernel_size = (kernel_size, kernel_size)
            self.padding = (kernel_size // 2, kernel_size // 2)
        else:
            raise ValueError("Invalid kernel size.")

        self.bias = bias
        # Each conv emits the reset, update and candidate pre-activations stacked on dim 1.
        self.x2h = nn.Conv2d(in_channels=input_size,
                             out_channels=hidden_size * 3,
                             kernel_size=self.kernel_size,
                             padding=self.padding,
                             bias=bias)

        self.h2h = nn.Conv2d(in_channels=hidden_size,
                             out_channels=hidden_size * 3,
                             kernel_size=self.kernel_size,
                             padding=self.padding,
                             bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        std = 1.0 / np.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, input, hx=None):
        """Advance the hidden state by one step.

        Args:
            input: tensor of shape (batch_size, input_size, height, width).
            hx: hidden state of shape (batch_size, hidden_size, height, width).
                Defaults to zeros.

        Returns:
            hy: the new hidden state, with the same layout as ``hx``.
        """
        if hx is None:
            # Plain tensors carry autograd state; the legacy Variable wrapper is unnecessary.
            hx = input.new_zeros(input.size(0), self.hidden_size, input.size(2), input.size(3))

        x_reset, x_upd, x_new = self.x2h(input).chunk(3, 1)
        h_reset, h_upd, h_new = self.h2h(hx).chunk(3, 1)

        reset_gate = torch.sigmoid(x_reset + h_reset)
        update_gate = torch.sigmoid(x_upd + h_upd)
        new_gate = torch.tanh(x_new + (reset_gate * h_new))

        return update_gate * hx + (1 - update_gate) * new_gate
    
# class GRUUpdateRule(nn.Module):

#     def __init__(self, in_planes, stride=1,inner_channels=1):
#         super().__init__()
#         self.inner_channels=inner_channels # out
#         self.conv1 = nn.Conv2d(in_planes, inner_channels, kernel_size=3,
#                                stride=stride, padding=1, bias=True)
        
#         self.conv2 = nn.Conv2d(in_planes, inner_channels, kernel_size=3,
#                                stride=stride, padding=1, bias=True)
        
#         self.reset_parameters()


#     def reset_parameters(self):
#         std = 1.0 / np.sqrt(self.hidden_size)
#         for w in self.parameters():
#             w.data.uniform_(-std, std)

#     def forward(self, old_state,new_state):

#         # Inputs:
#         #       input: of shape (batch_size, input_size, height_size, width_size)
#         #       hx: of shape (batch_size, hidden_size, height_size, width_size)
#         # Outputs:
#         #       hy: of shape (batch_size, hidden_size, height_size, width_size)

#         # if hx is None:
#         #     hx = Variable(input.new_zeros(input.size(0), self.hidden_size, input.size(2), input.size(3)))

#         x_t = self.conv1(input)
#         h_t = self.conv2(old_state)


#         x_reset, x_upd = x_t.chunk(2, 1)
#         h_reset, h_upd = h_t.chunk(2, 1)

#         reset_gate = torch.sigmoid(x_reset + h_reset)
#         update_gate = torch.sigmoid(x_upd + h_upd)
#         new_gate = torch.tanh(x_new + (reset_gate * h_new))

#         hy = update_gate * old_state + (1 - update_gate) * new_gate

#         return hy
    

class UpdateStateLayer(nn.Module):
    """Gated blend of an old and a new recurrent state with a single gate channel.

    The two states pass through separate 3x3 convolutions (``in_planes`` ->
    ``inner_channels``). The results are added, summed over the channel
    dimension (``keepdim=True``) and squashed with a sigmoid. This gives a
    one-channel gate ``g`` that broadcasts over the state channels. The output
    is ``g * old_state + (1 - g) * new_state``.
    """

    def __init__(self, in_planes, stride=1,inner_channels=1):
        super().__init__()
        self.inner_channels = inner_channels
        conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=True)
        self.conv1 = nn.Conv2d(in_planes, inner_channels, **conv_options)
        self.conv2 = nn.Conv2d(in_planes, inner_channels, **conv_options)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every weight and bias from U(-b, b) with b = 1/sqrt(inner_channels)."""
        bound = 1.0 / np.sqrt(self.inner_channels)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def forward(self, old_state, new_state):
        logits = self.conv1(old_state) + self.conv2(new_state)
        gate = torch.sigmoid(logits.sum(1, keepdim=True))  # (batch, 1, H', W')
        return gate * old_state + (1 - gate) * new_state

class UpdateStateLayerMultiChannel(nn.Module):
    """Gated blend of an old and a new state with an independent gate per channel.

    Two 3x3 convolutions (``in_planes`` -> ``in_planes``) map the old and new
    states. Their sum goes through a sigmoid, giving a gate ``g`` with as many
    channels as the states. The output is ``g * old_state + (1 - g) * new_state``.
    """

    def __init__(self, in_planes, stride=1):
        super().__init__()
        self.inner_channels = in_planes  # gate channels == state channels
        conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=True)
        self.conv1 = nn.Conv2d(in_planes, in_planes, **conv_options)
        self.conv2 = nn.Conv2d(in_planes, in_planes, **conv_options)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every weight and bias from U(-b, b) with b = 1/sqrt(in_planes)."""
        bound = 1.0 / np.sqrt(self.inner_channels)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def forward(self, old_state, new_state):
        # Gate shape: (batch, in_planes, H', W') -- one gate per channel and pixel.
        gate = torch.sigmoid(self.conv1(old_state) + self.conv2(new_state))
        return gate * old_state + (1 - gate) * new_state


class DTNetUpdateRule(nn.Module):
    """DeepThinking 2D network whose recurrent state is updated through a learned gate.

    Each iteration works as follows:

    - A 3x3 projection (``self.projection``) supplies the initial thought.
    - ``recur_block`` runs on the thought. With ``recall`` the thought is first
      concatenated with the input along channels.
    - ``update_layer(previous, proposal)`` blends the result into the previous
      thought.
    - A three-conv head maps each thought to a 2-channel output.

    Args:
        block: block class, called as ``block(in_planes, planes, stride, group_norm=...)``
            and exposing ``block.expansion``.
        num_blocks: iterable of block counts, one ``_make_layer`` stage per entry.
        width: channels of the thought.
        in_channels: channels of the input.
        recall: concatenate the input to the thought before every recurrent step.
        group_norm: forwarded to every block.
        update_inner_channels: ``inner_channels`` of ``UpdateStateLayer``.
        update_multi_channel: use ``UpdateStateLayerMultiChannel`` instead.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,
                 update_inner_channels=1, update_multi_channel=False, **kwargs):
        super().__init__()

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=False)

        # Constructed unconditionally (keeps later layers' RNG draws independent of
        # ``recall``) but only registered when ``recall`` is set.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)
        if update_multi_channel:
            self.update_layer = UpdateStateLayerMultiChannel(width)
        else:
            self.update_layer = UpdateStateLayer(width, inner_channels=update_inner_channels)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of blocks with activations between them.

        A ReLU follows every block except the last, which is followed by Tanh.
        This method advances ``self.width`` to ``planes * block.expansion``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i, strd in enumerate(strides):
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            layers.append(nn.ReLU() if i < len(strides) - 1 else nn.Tanh())
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` gated recurrent steps.

        Returns:
            Training mode: ``(out, interim_thought)``, or
            ``(all_outputs, out, interim_thought)`` when ``return_all_outputs`` is set.
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, H, W)``.
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)
        for i in range(iters_to_do):
            if self.recall:
                recur_input = torch.cat([interim_thought, x], 1)
            else:
                # Without recall the block consumes the thought alone (this branch
                # previously left the input unassigned -> UnboundLocalError).
                recur_input = interim_thought
            proposal = self.recur_block(recur_input)

            interim_thought = self.update_layer(interim_thought, proposal)

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

def dt_net_recall_update_rule_2d(width, **kwargs):
    """Build a recall ``DTNetUpdateRule`` with one stage of two ``BasicBlock2DNoFinalRelu``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetUpdateRule(BasicBlock2DNoFinalRelu, [2], **options)

def dt_net_recall_update_rule_morehidden_2d(width, **kwargs):
    """Like ``dt_net_recall_update_rule_2d`` but with 20 inner channels in the update gate.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True,
                   update_inner_channels=20)
    return DTNetUpdateRule(BasicBlock2DNoFinalRelu, [2], **options)



class DTNetUpdateRuleV2(nn.Module):
    """Gated-update DeepThinking 2D network that starts from an all-zero thought.

    It matches ``DTNetUpdateRule`` except that there is no input projection.
    When no ``interim_thought`` is given, the recurrence starts from zeros of
    shape ``(batch, self.width, H, W)``. Each step runs ``recur_block`` (with
    the input concatenated when ``recall`` is set). ``update_layer`` blends the
    result into the previous thought, and a three-conv head maps the thought to
    2 channels.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,
                 update_inner_channels=1, update_multi_channel=False, **kwargs):
        super().__init__()

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Constructed unconditionally but only registered when ``recall`` is set.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        if update_multi_channel:
            self.update_layer = UpdateStateLayerMultiChannel(width)
        else:
            self.update_layer = UpdateStateLayer(width, inner_channels=update_inner_channels)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of blocks: ReLU after each block except the last, Tanh after the last.

        This method advances ``self.width`` to ``planes * block.expansion``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i, strd in enumerate(strides):
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            layers.append(nn.ReLU() if i < len(strides) - 1 else nn.Tanh())
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` gated recurrent steps.

        Returns:
            Training mode: ``(out, interim_thought)``, or
            ``(all_outputs, out, interim_thought)`` when ``return_all_outputs`` is set.
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, H, W)``.
        """
        if interim_thought is None:
            interim_thought = torch.zeros(x.shape[0], self.width, x.shape[2], x.shape[3],
                                          device=x.device)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)
        for i in range(iters_to_do):
            if self.recall:
                recur_input = torch.cat([interim_thought, x], 1)
            else:
                # Without recall the block consumes the thought alone (this branch
                # previously left the input unassigned -> UnboundLocalError).
                recur_input = interim_thought
            proposal = self.recur_block(recur_input)

            interim_thought = self.update_layer(interim_thought, proposal)

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

def dt_net_recall_update_rule_v2_2d(width, **kwargs):
    """Build a recall ``DTNetUpdateRuleV2`` with one stage of two ``BasicBlock2DNoFinalRelu``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetUpdateRuleV2(BasicBlock2DNoFinalRelu, [2], **options)

def dt_net_recall_update_rule_v2_morehidden_2d(width, **kwargs):
    """Like ``dt_net_recall_update_rule_v2_2d`` but with 20 inner channels in the update gate.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True,
                   update_inner_channels=20)
    return DTNetUpdateRuleV2(BasicBlock2DNoFinalRelu, [2], **options)

def dt_net_recall_uprule_v2_multic_2d(width, **kwargs):
    """Like ``dt_net_recall_update_rule_v2_2d`` but using the per-channel update gate.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True,
                   update_multi_channel=True)
    return DTNetUpdateRuleV2(BasicBlock2DNoFinalRelu, [2], **options)


class DTNetUpdateRuleReduceMaxPool(nn.Module):
    """Gated-update DeepThinking 2D network that reduces each thought to a vector.

    The recurrence matches ``DTNetUpdateRule``: a projection supplies the
    initial thought, ``recur_block`` runs each step (with the input
    concatenated when ``recall`` is set), and ``UpdateStateLayer`` blends the
    result into the previous thought.

    The head is conv(width->width) + ReLU, then adaptive max (or avg) pooling
    to 1x1, then conv(width->32) + ReLU, then conv(32->output_size). Its
    output is viewed as ``(batch, output_size)``.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True,
                 group_norm=False, use_AvgPool=False, update_inner_channels=1, **kwargs):
        super().__init__()

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=False)

        # Constructed unconditionally but only registered when ``recall`` is set.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        if use_AvgPool:
            head_pool = nn.AdaptiveAvgPool2d(output_size=1)
        else:
            head_pool = nn.AdaptiveMaxPool2d(output_size=1)

        head_conv1 = nn.Conv2d(width, width, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(32, output_size, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_pool,
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.output_size = output_size

        self.update_layer = UpdateStateLayer(width, inner_channels=update_inner_channels)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of blocks: ReLU after each block except the last, Tanh after the last.

        This method advances ``self.width`` to ``planes * block.expansion``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i, strd in enumerate(strides):
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            layers.append(nn.ReLU() if i < len(strides) - 1 else nn.Tanh())
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,  **kwargs):
        """Run ``iters_to_do`` gated recurrent steps.

        Returns:
            Training mode: ``(out, interim_thought)``, or
            ``(all_outputs, out, interim_thought)`` when ``return_all_outputs`` is set.
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, output_size)``.
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought

        all_outputs = torch.zeros((x.size(0), iters_to_do, self.output_size), device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                recur_input = torch.cat([interim_thought, x], 1)
            else:
                # Without recall the block consumes the thought alone (this branch
                # previously left the input unassigned -> UnboundLocalError).
                recur_input = interim_thought
            proposal = self.recur_block(recur_input)
            interim_thought = self.update_layer(interim_thought, proposal)

            out = self.head(interim_thought).view(x.size(0), self.output_size)

            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs
    
# def dt_net_recall_2d_out4(width, **kwargs):
#     return DTNetReduceMaxPool(BasicBlock, [2], width=width, output_size=4, in_channels=kwargs["in_channels"], recall=True)

def dt_net_recall_update_rule_morehidden_2d_out4(width, **kwargs):
    """Build a recall ``DTNetUpdateRuleReduceMaxPool`` with 4 outputs and a 20-channel gate.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(output_size=4, width=width, in_channels=kwargs["in_channels"],
                   recall=True, update_inner_channels=20)
    return DTNetUpdateRuleReduceMaxPool(BasicBlock2DNoFinalRelu, [2], **options)




class DTNetCustom(nn.Module):
    """DeepThinking Network 2D model class"""

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, **kwargs):
        super().__init__()

        self.name = "DTNetCustom"

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        # proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
        #                       stride=1, padding=1, bias=False)

        # conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
        #                         stride=1, padding=1, bias=False)

        recur_layers = []
        # if recall:
        #     recur_layers.append(conv_recall)

        assert num_blocks == 1

        # for i in range(num_blocks):
        #     # if i==0 and rec
        #     # recur_layers.append(self._make_layer(block, width, num_blocks[i], stride=1))
        #     if i==0 and recall:
        #         recur_layers.append(block(self.width+in_channels, self.width, 1, group_norm=self.group_norm))
        #     else:
        #         recur_layers.append(block(self.width, self.width, 1, group_norm=self.group_norm))



        if recall:
            self.recur_block = block(self.width+in_channels, self.width, 1, group_norm=self.group_norm)
        else:
            self.recur_block = block(self.width, self.width, 1, group_norm=self.group_norm)
            



        # recur_layers

        out_conv = nn.Conv2d(width, 2, kernel_size=1,
                               stride=1, padding=0, bias=False)

        # head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
        #                        stride=1, padding=1, bias=False)
        # head_conv2 = nn.Conv2d(32, 2, kernel_size=1,
        #                        stride=1, bias=False)
        # head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
        #                        stride=1, padding=1, bias=False)

        # self.projection = nn.Sequential(proj_conv, nn.ReLU())
        # self.recur_block = nn.Sequential(*recur_layers)
        

        self.head = nn.Sequential(
                                out_conv
                                # head_conv1, nn.ReLU(),
                                #   head_conv2, nn.ReLU(),
                                #   head_conv3
                                  )

    # def _make_layer(self, block, planes, num_blocks, stride):
    #     strides = [stride] + [1]*(num_blocks-1)
    #     layers = []
    #     for strd in strides:
    #         layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
    #         self.width = planes * block.expansion
    #     return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        # initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = torch.zeros(x.shape[0],self.width, x.shape[2], x.shape[3]).to(x.device)


        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3))).to(x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = self.recur_block(torch.cat([interim_thought, x], 1), interim_thought)
            else:
                interim_thought = self.recur_block(interim_thought, interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs

    # def forward_L1_interim(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,only_final_and_next=False, **kwargs):
    #     # initial_thought = self.projection(x)

    #     if interim_thought is None:
    #         interim_thought = torch.zeros(x.shape[0],self.width, x.shape[2], x.shape[3]).to(x.device)


    #     all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3))).to(x.device)
        
    #     previous_interim_thought = interim_thought
    #     noise_loss = 0
    #     noise_interims = [interim_thought]
    #     for i in range(iters_to_do):
    #         if self.recall:
    #             interim_thought = torch.cat([interim_thought, x], 1)
    #         interim_thought = self.recur_block(interim_thought)
            
    #         ## L1 interim
    #         if not only_final_and_next:
    #             noise_loss += torch.mean(torch.abs(interim_thought - previous_interim_thought))
    #             previous_interim_thought = interim_thought
    #         else:
    #             noise_interims.append(interim_thought)

    #         out = self.head(interim_thought)
    #         all_outputs[:, i] = out

    #     if only_final_and_next:
    #         final_interim = noise_interims[-1]
    #         for i in range(len(noise_interims) - 1):
    #             noise_loss += torch.mean(torch.abs(noise_interims[i] - final_interim))
            

    #     if self.training:
    #         if return_all_outputs:
    #             return all_outputs, out, interim_thought, noise_loss
    #         else:
    #             return out, interim_thought, noise_loss

    #     return all_outputs

def dt_net_custom_2d_1conv(width, **kwargs):
    """Build a recall ``DTNetCustom`` around ``Block1Conv2d``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetCustom(Block1Conv2d, 1, **options)

def dt_net_custom_2d_1conv_add(width, **kwargs):
    """Build a recall ``DTNetCustom`` around ``Block1Conv2dAdd``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetCustom(Block1Conv2dAdd, 1, **options)

def dt_net_custom_2d_simple_block(width, **kwargs):
    """Build a recall ``DTNetCustom`` around ``BasicBlock2DSimple``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetCustom(BasicBlock2DSimple, 1, **options)

def dt_net_custom_2d_1conv_tanh(width, **kwargs):
    """Build a recall ``DTNetCustom`` around ``Block1Conv2dTanh``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetCustom(Block1Conv2dTanh, 1, **options)

def dt_net_custom_2d_1conv_add_tanh(width, **kwargs):
    """Build a recall ``DTNetCustom`` around ``Block1Conv2dAddTanh``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetCustom(Block1Conv2dAddTanh, 1, **options)

def dt_net_custom_2d_simple_block_tanh(width, **kwargs):
    """Build a recall ``DTNetCustom`` around ``BasicBlock2DSimpleTanh``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetCustom(BasicBlock2DSimpleTanh, 1, **options)




class DTNetGruConv(nn.Module):
    """DeepThinking-style 2D network whose recurrence is a convolutional GRU cell.

    Every iteration feeds the raw input ``x`` and the current hidden state to
    ``recur_block``. The new state is decoded by a 1x1 convolution into a
    2-channel map. Only ``num_blocks == 1`` and ``recall=True`` are accepted.
    """

    def __init__(self, num_blocks, width, in_channels=3, recall=True, group_norm=False, **kwargs):
        super().__init__()

        # NOTE(review): this class reports the name "DTNetCustom". Left unchanged in
        # case other code dispatches on it -- confirm whether that is intended.
        self.name = "DTNetCustom"

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm  # stored, not used by this class

        assert num_blocks == 1

        # NOTE(review): ``ConvGRUCell`` is not defined in this part of the file. It is
        # presumably provided elsewhere (e.g. ``from .blocks import *``). Its call
        # signature and ``.padding`` attribute match ``Conv2dGRUCell`` in this file --
        # confirm which cell is intended.
        self.recur_block = ConvGRUCell(in_channels, width, (3, 3), bias=True)

        # Presumably required so the hidden state keeps the input's spatial size.
        assert self.recur_block.padding == (1, 1)

        self.head = nn.Sequential(
            nn.Conv2d(width, 2, kernel_size=1, stride=1, padding=0, bias=False)
        )

        assert recall

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` GRU steps starting from zeros (or ``interim_thought``).

        Returns:
            Training mode: ``(out, interim_thought)``, or
            ``(all_outputs, out, interim_thought)`` when ``return_all_outputs`` is set.
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, H, W)``.
        """
        hidden = interim_thought
        if hidden is None:
            hidden = torch.zeros(x.shape[0], self.width, x.shape[2], x.shape[3], device=x.device)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        for step in range(iters_to_do):
            # The raw input is given to the cell at every step, alongside the hidden state.
            hidden = self.recur_block(x, hidden)
            out = self.head(hidden)
            all_outputs[:, step] = out

        if not self.training:
            return all_outputs
        if return_all_outputs:
            return all_outputs, out, hidden
        return out, hidden



def dt_net_gru_2d(width, **kwargs):
    """Build a ``DTNetGruConv`` with a single GRU cell and recall enabled.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetGruConv(1, **options)




class DTNetNoDiagonal(nn.Module):
    """DeepThinking 2D network whose 3x3 convolutions ignore diagonal neighbours.

    The architecture is projection -> recurrent block (with recall) -> three-conv
    head. At the start of every forward pass, the four corner taps of every 3x3
    ``nn.Conv2d`` weight in the model are zeroed. Each such convolution therefore
    only sees a pixel and its four edge-adjacent neighbours.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, **kwargs):
        super().__init__()

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=False)

        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=False)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        # 3x3 mask with zeros in the four corners. Plain attribute (not a registered
        # buffer): it is not in state_dict and is not moved by ``.to()``, so
        # ``_set_zero_kernel_diagonal`` moves it lazily.
        self.kernel_mask = torch.ones((1, 1, 3, 3))
        self.kernel_mask[:, :, 0::2, 0::2] = 0

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks (stride only on the first); advances ``self.width``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def _set_zero_kernel_diagonal(self):
        """Zero the corner taps of every 3x3 ``nn.Conv2d`` weight, in place.

        The write goes through ``weight.data``, so autograd does not record it
        (as with the previous ``.data`` reassignment). The mask is cast to each
        weight's dtype, so half-precision weights are no longer replaced by
        float32 copies. Convolutions whose kernel is not exactly 3x3 (e.g. 3x1)
        are skipped instead of being broadcast into a 3x3 weight.
        """
        for module in self.modules():
            if not isinstance(module, nn.Conv2d) or tuple(module.kernel_size) != (3, 3):
                continue
            weight = module.weight.data
            if self.kernel_mask.device != weight.device:
                self.kernel_mask = self.kernel_mask.to(weight.device)
            weight.mul_(self.kernel_mask.to(dtype=weight.dtype))

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Mask the kernels, then run ``iters_to_do`` recurrent steps.

        Returns:
            Training mode: ``(out, interim_thought)``, or
            ``(all_outputs, out, interim_thought)`` when ``return_all_outputs`` is set.
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, H, W)``.
        """
        self._set_zero_kernel_diagonal()

        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

def dt_net_recall_2d_only_crosses(width, **kwargs):
    """Build a recall ``DTNetNoDiagonal`` with one stage of two ``BasicBlock``.

    ``kwargs`` must contain ``in_channels`` (otherwise KeyError); other keys are ignored.
    """
    options = dict(width=width, in_channels=kwargs["in_channels"], recall=True)
    return DTNetNoDiagonal(BasicBlock, [2], **options)



##### reduce the messages

class ReccurentOutput(nn.Module):
    """Apply one activation to all channels but the last, and another to the last.

    Channels ``[:-1]`` go through ``net_activation``. The last channel goes
    through ``filter_activation``, which receives a ``(batch, height, width)``
    slice. The result is a new tensor in ``x``'s dtype; ``x`` itself is not
    modified.
    """

    def __init__(self, net_activation,filter_activation,**kwargs) -> None:
        super().__init__()
        self.net_activation = net_activation
        self.filter_activation = filter_activation

    def forward(self, x):
        # x -> (batch, channels, height, width)
        # Built out of place. Writing the activations back into ``x`` mutated a
        # tensor autograd may need: an activation that saves its input (torch.abs,
        # the straight-through functions in this file, ...) then made backward() fail.
        net_part = self.net_activation(x[:, :-1, :, :]).to(x.dtype)
        filter_part = self.filter_activation(x[:, -1, :, :]).to(x.dtype).unsqueeze(1)
        return torch.cat([net_part, filter_part], dim=1)

class ReccurentReluLinear(nn.Module):
    """ReLU on every channel except the last; the last channel passes through unchanged."""

    def __init__(self,**kwargs) -> None:
        super().__init__()

    def forward(self, x):
        # x -> (batch, channels, height, width)
        # Out of place. The former in-place ReLU on a slice of ``x`` mutated the
        # caller's tensor and broke backward() whenever ``x`` had been saved by its
        # producer (e.g. the output of tanh, sigmoid or relu).
        return torch.cat([torch.relu(x[:, :-1, :, :]), x[:, -1:, :, :]], dim=1)

## clamped straight through abs
from torch.autograd import Function
class _ClampAbsST(Function):
    """``clamp(|x|, 0, 1)`` with a straight-through gradient through the clamp.

    Backward treats the clamp as the identity and differentiates only ``|x|``.
    The upstream gradient is multiplied by ``sign(x)``; where ``x == 0`` it
    passes through unchanged.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.clamp(input.abs(), 0, 1)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Scale, do not overwrite, the upstream gradient. The previous version
        # assigned the constants +1/-1, discarding grad_output wherever x != 0.
        return torch.where(input < 0, -grad_output, grad_output)


clamp_abs_st = _ClampAbsST.apply

class _ClapReluST(Function):
    """``clamp(relu(x), 0, 1)`` with a gradient that is only cut below zero.

    Backward zeroes the incoming gradient where the input was negative. It
    passes it through unchanged everywhere else, including where the upper
    clamp at 1 was active (straight-through).
    """

    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        return inp.relu().clamp(0, 1)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_inp,) = ctx.saved_tensors
        return grad_output.masked_fill(saved_inp < 0, 0)


clamp_relu_st = _ClapReluST.apply

class _ClampReluLinear(Function):
    """``clamp(relu(x), 0, 1)`` forward, identity gradient backward."""

    @staticmethod
    def forward(ctx, inp):
        # The backward pass does not need the input, so nothing is saved. Saving
        # it kept the tensor alive for no reason, and made backward() fail whenever
        # the input (or the tensor it is a view of) was later modified in place.
        return torch.clamp(F.relu(inp), 0, 1)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


clamp_relu_linear = _ClampReluLinear.apply

class _StepST(Function):
    """Heaviside step ``(x > 0)`` as float32, with an identity (straight-through) gradient."""

    @staticmethod
    def forward(ctx, inp):
        # Nothing is saved. The identity backward does not need the input, and a
        # saved view would make backward() fail after any in-place write to it.
        return (inp > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


step_st = _StepST.apply

class SampleMask(nn.Module):
    """Binarise the last channel of a mask tensor.

    * Training: Bernoulli sample with the channel's values as probabilities,
      using a straight-through gradient: the forward value is the sample, and
      the gradient flows to the probabilities.
    * Eval: threshold at 0.5.

    Returns a new tensor, and the input is not modified. In-place writes would
    corrupt tensors saved for backward (e.g. the output of ``sigmoid``) and
    alias the recurrent state the mask may be a view of.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

    def forward(self, x):
        probs = x[:, -1:, :, :]
        if self.training:
            binary = torch.bernoulli(probs.detach()) + probs - probs.detach()
        else:
            binary = (probs > 0.5).to(x.dtype)
        return torch.cat([x[:, :-1, :, :], binary], dim=1)

class DTNetReduceMessages(nn.Module):
    """DeepThinking-style 2D recurrent network with a gating ("filter") state channel.

    The recurrent state has ``width + 1`` channels. Its last channel is turned
    into a per-pixel mask that blends each freshly computed state with the
    previous one::

        state_i = next_i * mask_{i-1} + state_{i-1} * (1 - mask_{i-1})

    ``mask_{i-1}`` is derived from the state produced one step earlier. For the
    first iteration it comes from the projection.

    Args:
        block: block class for the recurrent stages; it must expose ``expansion``
            and accept ``group_norm`` and ``bias`` keyword arguments.
        num_blocks: number of blocks in each recurrent stage.
        width: number of feature channels; one extra mask channel is added.
        in_channels: number of input channels.
        recall: if True, the input is concatenated to the state before every
            recurrent step.
        group_norm: forwarded to ``block``.
        bias: whether the convolutions (and blocks) use a bias term.
        filter_activation: callable applied to the mask channel.
        block_activation_mode: activation inserted between consecutive blocks of
            a stage: ``'relu_plus_filter'``, ``'all_relu'`` or
            ``'relu_plus_linear'``.
        b_inp_mode: how the next state and mask are derived from the recurrent
            output: ``'b_inp_external'`` or ``'b_inp_internal'``.
        sample_mask: if True, masks are binarised by ``SampleMask``; otherwise
            they are used as-is.
        **kwargs: ignored.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 filter_activation=F.sigmoid, block_activation_mode='relu_plus_filter',
                 b_inp_mode='b_inp_external', sample_mask=False, **kwargs):
        super().__init__()

        self.name = "DTNetReduceMessages"

        self.bias = bias

        self.recall = recall
        # One extra channel carries the (pre-activation) mask.
        width = int(width) + 1
        self.width = int(width)
        self.group_norm = group_norm
        # Modules are created in a fixed order on purpose. Default Conv2d
        # initialisation draws from the global RNG, so reordering would change the
        # initial weights.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        assert block_activation_mode in ['relu_plus_filter', 'all_relu', 'relu_plus_linear']
        self.block_activation_mode = block_activation_mode
        self.filter_activation = filter_activation

        # ReLU on the feature channels, filter_activation on the mask channel.
        self.recurrent_activation = ReccurentOutput(F.relu, filter_activation)

        assert b_inp_mode in ['b_inp_external', 'b_inp_internal']
        self.b_inp_mode = b_inp_mode

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        # FIXME: the projection has no activation of its own; the initial state is
        # activated by get_iterim_next in forward instead.
        self.projection = nn.Sequential(proj_conv)

        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        if sample_mask:
            self.sample_mask = SampleMask()
        else:
            self.sample_mask = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of blocks; only the first block uses ``stride``.

        An activation chosen by ``block_activation_mode`` is inserted between
        consecutive blocks, but not after the last one. The last block's raw
        output is handled by ``get_iterim_next`` in ``forward``. Updates
        ``self.width`` to the stage's output channels.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        block_nrs = len(strides)
        for i in range(block_nrs):
            layers.append(block(self.width, planes, strides[i], group_norm=self.group_norm, bias=self.bias))

            if i != block_nrs - 1:
                if self.block_activation_mode == 'all_relu':
                    layers.append(nn.ReLU())

                elif self.block_activation_mode == 'relu_plus_filter':
                    layers.append(ReccurentOutput(nn.ReLU(), self.filter_activation))

                elif self.block_activation_mode == 'relu_plus_linear':
                    layers.append(ReccurentReluLinear())

                else:
                    raise NotImplementedError

            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def get_mask_loss(self, mask):
        """Return the mean of ``mask``.

        For the filter activations listed below, the mask is clamped at 0 before
        averaging. The original note says: don't apply the L1 penalty below 0.
        """
        if self.filter_activation in (clamp_relu_linear, clamp_relu_linear_1,
                                      clamp_relu_linear_centered, step_st):
            return torch.mean(torch.clamp(mask, min=0))
        return torch.mean(mask)

    def get_iterim_next(self, interim_thought):
        """Split a raw recurrent output into the next state and its gate mask.

        Returns:
            ``(iterim, mask)``. ``mask`` is the last channel kept as a dimension,
            i.e. ``(batch, 1, H, W)`` for a 4-D input. It therefore broadcasts
            over the state channels and fits the ``all_masks`` buffer in
            ``forward``. ``mask`` then goes through ``self.sample_mask``.
        """
        if self.b_inp_mode == 'b_inp_external':
            iterim = self.recurrent_activation(interim_thought)
            # Keep only the channel dim. An extra trailing unsqueeze(-1) made the
            # mask 5-D, which neither broadcasts against the 4-D state nor fits
            # all_masks[:, i].
            mask = iterim[:, -1:, :, :]
        elif self.b_inp_mode == 'b_inp_internal':
            if self.block_activation_mode == 'all_relu':
                iterim = F.relu(interim_thought)
                # The mask is computed from the pre-ReLU channel.
                mask = self.filter_activation(interim_thought[:, -1:, :, :])

            elif self.block_activation_mode == 'relu_plus_filter':
                iterim = self.recurrent_activation(interim_thought)
                mask = iterim[:, -1:, :, :]

            elif self.block_activation_mode == 'relu_plus_linear':
                # In-place ReLU on the feature channels; the mask channel stays linear.
                F.relu(interim_thought[:, :-1, :, :], inplace=True)
                iterim = interim_thought
                mask = self.filter_activation(interim_thought[:, -1:, :, :])

            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

        mask = self.sample_mask(mask)

        return iterim, mask

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, return_masks=False, **kwargs):
        """Run ``iters_to_do`` gated recurrent iterations.

        Args:
            x: input batch, indexed as ``(batch, channels, height, width)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: must be None; resuming from a given state is not
                supported.
            return_all_outputs: in training mode, also return every iteration's
                head output.
            return_masks: in training mode (and if ``return_all_outputs`` is
                False), also return the mask used at every iteration.

        Returns:
            In training mode: ``(all_outputs, out, state)``,
            ``(out, state, all_masks)`` or ``(out, state)``. In eval mode:
            ``all_outputs`` of shape ``(batch, iters_to_do, 2, height, width)``.
        """
        initial_thought = self.projection(x)

        if interim_thought is not None:
            raise NotImplementedError
        interim_thought = initial_thought

        # TODO: the initial state goes through the same activation as recurrent
        # outputs for now; it may deserve a dedicated treatment.
        interim_thought, mask = self.get_iterim_next(interim_thought)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)
        all_masks = torch.zeros((x.size(0), iters_to_do, 1, x.size(2), x.size(3)), device=x.device)

        for i in range(iters_to_do):
            interim_thought_previous = interim_thought

            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)

            interim_thought_next = self.recur_block(interim_thought)
            interim_thought_next, next_mask = self.get_iterim_next(interim_thought_next)

            # Gate: take the new state where mask == 1, keep the old one where mask == 0.
            interim_thought = interim_thought_next * mask + interim_thought_previous * (1 - mask)

            out = self.head(interim_thought)
            all_outputs[:, i] = out
            all_masks[:, i] = mask

            # The mask derived from this step's output gates the next step.
            mask = next_mask

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            elif return_masks:
                return out, interim_thought, all_masks
            else:
                return out, interim_thought

        return all_outputs
    
def sig_1(x):
    """Sigmoid centred at +1: ``sigmoid(x - 1)``."""
    shifted = x - 1
    return torch.sigmoid(shifted)

def sig_p1(x):
    """Sigmoid centred at -1: ``sigmoid(x + 1)``."""
    shifted = x + 1
    return torch.sigmoid(shifted)

def clamp_relu_centered(x):
    """``clamp_relu_st`` with its input shifted by +0.5."""
    shifted = x + 0.5
    return clamp_relu_st(shifted)

def clamp_relu_1(x):
    """``clamp_relu_st`` with its input shifted by +1."""
    shifted = x + 1
    return clamp_relu_st(shifted)

def clamp_relu_linear_centered(x):
    """``clamp_relu_linear`` with its input shifted by +0.5."""
    shifted = x + 0.5
    return clamp_relu_linear(shifted)

def clamp_relu_linear_1(x):
    """``clamp_relu_linear`` with its input shifted by +1."""
    shifted = x + 1
    return clamp_relu_linear(shifted)

def step_1(x):
    """``step_st`` with its input shifted by +1 (fires for ``x > -1``)."""
    shifted = x + 1
    return step_st(shifted)

# Registry of mask ("filter") activations; the keys appear in the generated
# dt_reduce_message_* factory names.
filter_activations = {
    # sigmoid variants
    'sig': torch.sigmoid,
    'sig-1': sig_1,
    'sig_p1': sig_p1,
    # clamps with custom (straight-through) gradients
    'abs': clamp_abs_st,
    'relu': clamp_relu_st,
    'relu-0.5': clamp_relu_centered,
    'relu-1': clamp_relu_1,
    # clamps with identity gradient
    'relu_linear': clamp_relu_linear,
    'relu_linear-0.5': clamp_relu_linear_centered,
    'relu_linear-1': clamp_relu_linear_1,
    # hard steps with identity gradient
    'step': step_st,
    'step_1': step_1,
}

use_bias = {'bias': True, 'nobias': False}

# Identity maps: the key goes into factory names, the value into the model.
internal_block_activation_modes = {
    mode: mode for mode in ('all_relu', 'relu_plus_filter', 'relu_plus_linear')
}

b_inp_modes = {mode: mode for mode in ('b_inp_external', 'b_inp_internal')}

sample_masks = {'sample': True, 'no_sample': False}


def _make_reduce_messages_factory(name, filter_activation, bias, block_activation_mode,
                                  b_inp_mode, sample_mask):
    """Return a ``factory(width, **kwargs)`` that builds one fixed DTNetReduceMessages config.

    The configuration is captured as arguments of this helper, so each factory
    keeps its own values. A lambda defined directly in the loop below would read
    the loop variables when called, and every factory would build the last
    configuration.
    """

    def factory(width, **kwargs):
        return DTNetReduceMessages(BasicBlock2DNoFinalRelu, [2], width=width,
                                   in_channels=kwargs["in_channels"], recall=True,
                                   filter_activation=filter_activation, bias=bias,
                                   block_activation_mode=block_activation_mode,
                                   b_inp_mode=b_inp_mode,
                                   # the constructor parameter is ``sample_mask``;
                                   # ``sample_masks=`` was silently eaten by **kwargs
                                   sample_mask=sample_mask)

    factory.__name__ = factory.__qualname__ = name
    return factory


for filter_activation_k, filter_activation_v in filter_activations.items():
    for use_bias_k, use_bias_v in use_bias.items():
        for internal_block_activation_mode_k, internal_block_activation_mode_v in internal_block_activation_modes.items():
            for b_inp_mode_k, b_inp_mode_v in b_inp_modes.items():
                for sample_masks_k, sample_masks_v in sample_masks.items():
                    function_name = f'dt_reduce_message_{use_bias_k}_{filter_activation_k}_{b_inp_mode_k}_{internal_block_activation_mode_k}_{sample_masks_k}_2d'

                    globals()[function_name] = _make_reduce_messages_factory(
                        function_name, filter_activation_v, use_bias_v,
                        internal_block_activation_mode_v, b_inp_mode_v, sample_masks_v)

# def dt_net_reduce_message_2d(width, **kwargs):
#     return DTNetReduceMessages(BasicBlock2DNoFinalRelu, [2], width=width, in_channels=kwargs["in_channels"], recall=True)


# def dt_net_reduce_message_2d_with_bias(width, **kwargs):
#     return DTNetReduceMessages(BasicBlock2DNoFinalRelu, [2], width=width, in_channels=kwargs["in_channels"], recall=True,bias=True)


###

class FlowNet_Conv3x3(nn.Module):
    """Recurrent "flow" network with 3x3-conv projection and head.

    The input is projected to ``width`` channels (3x3 conv + ReLU). With
    ``recall``, the raw input is concatenated once, before the first
    iteration, giving a state with ``width + in_channels`` channels
    (``self.in_planes``). ``self.recur_block`` is applied ``iters_to_do`` times,
    and a 3-layer 3x3-conv head maps each state to 2 channels.
    """

    @staticmethod
    def initialisation(m: nn.Module):
        # Kaiming-normal (ReLU gain) weights and zero biases for every Conv2d;
        # applied to all submodules through ``self.apply`` in ``__init__``.
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight,nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)


    def zero_last_recurrent_layer(self):
        # Scale the ``conv2`` weights of the last block of the last stage by 1e-4
        # (called after initialisation; see "avoid explosion" in __init__).
        # Original note: enough for 1000 iterations, without worrying about symmetry.
        # NOTE(review): assumes ``block`` exposes a ``conv2`` attribute — confirm in .blocks
        self.recur_block[-1][-1].conv2.weight.data*=0.0001
        
        # # in principle the symmetry is broken by the following filters
        # self.recur_block[-1][-1].conv2.weight.data.zero_() 

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,**kwargs):
        # ``group_norm`` is stored but not used by the code below; ``kwargs`` are ignored.
        super().__init__()

        self.name = "FlowNet"

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=False)

        # conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=1,
        #                         stride=1, padding=0, bias=False)

        # Channels of the recurrent state: the projected width, plus the raw input
        # channels when recall concatenates x onto the initial state.
        self.in_planes = width

        recur_layers = []
        if recall:
            self.in_planes += in_channels

        for i in range(len(num_blocks)):
            recur_layers.append(self._make_layer(block, num_blocks[i], stride=1))

        # NOTE(review): the head (and the next iteration) consume the recurrent output
        # with ``self.in_planes`` channels, so ``block`` presumably returns that many
        # channels — confirm in .blocks.
        head_conv1 = nn.Conv2d(self.in_planes, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)
        

        self.apply(self.initialisation)

        # avoid explosion
        self.zero_last_recurrent_layer()
        

    def _make_layer(self, block, num_blocks, stride):
        # One stage of ``num_blocks`` blocks (at least one); only the first uses ``stride``.
        # Every block maps ``self.in_planes`` inputs with ``self.width`` as its second argument.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        assert block.expansion == 1
        for strd in strides:
            layers.append(block(self.in_planes, self.width, strd))
            # self.width = self.width * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,return_recur_inter=False, **kwargs):
        """Run the recurrent module ``iters_to_do`` times.

        Args:
            x: input batch; sizes 0, 2 and 3 are read as batch, height, width.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional initial state, used as-is (no recall
                concatenation is applied to it). The projection is still computed.
            return_all_outputs: must be False in training mode (asserted).
            return_recur_inter: in training mode, also return the second value
                produced by ``self.recur_block`` on the last iteration.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(out, interim_thought, recur_inter)`` from the last iteration.
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, H, W)``.
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            if self.recall:
                interim_thought = torch.cat([initial_thought, x], 1)
            else:
                interim_thought = initial_thought

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3))).to(x.device)

        for i in range(iters_to_do):
            # FIXME: assumes recur_block returns (state, intermediate). With several blocks
            # in one Sequential, each block would receive the previous block's tuple.
            interim_thought, recur_inter = self.recur_block(interim_thought) ## fixme if recurrent inter

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            assert not return_all_outputs

            if return_recur_inter:
                return out, interim_thought, recur_inter
            else:
                return out, interim_thought

        return all_outputs


def flownet_recall_2d_3x3_1block3x3(width, **kwargs):
    """Build a ``FlowNet_Conv3x3`` with recall and block counts ``[1]`` of ``FlowBlock2D3x3``.

    ``kwargs`` must provide ``in_channels``; other keys are ignored.
    """
    stage_sizes = [1]
    return FlowNet_Conv3x3(FlowBlock2D3x3, stage_sizes, width=width,
                           in_channels=kwargs["in_channels"], recall=True)

def flownet_recall_2d_3x3_2block3x3(width, **kwargs):
    """Build a ``FlowNet_Conv3x3`` with recall and block counts ``[2]`` of ``FlowBlock2D3x3``.

    ``kwargs`` must provide ``in_channels``; other keys are ignored.
    """
    stage_sizes = [2]
    return FlowNet_Conv3x3(FlowBlock2D3x3, stage_sizes, width=width,
                           in_channels=kwargs["in_channels"], recall=True)


def flownet_recall_2d_3x3_1block3x3_tanh(width, **kwargs):
    """Build a ``FlowNet_Conv3x3`` with recall and block counts ``[1]`` of ``FlowBlock2D3x3Tanh``.

    ``kwargs`` must provide ``in_channels``; other keys are ignored.
    """
    stage_sizes = [1]
    return FlowNet_Conv3x3(FlowBlock2D3x3Tanh, stage_sizes, width=width,
                           in_channels=kwargs["in_channels"], recall=True)



class FlowNet_Conv1(nn.Module):
    """Recurrent "flow" network whose projection and head use 1x1 convolutions.

    Same structure as a 3x3 FlowNet, except that the projection
    (1x1 conv + ReLU) and the 3-layer head only mix channels per pixel. Any
    spatial interaction must therefore come from ``self.recur_block``. With
    ``recall``, the raw input is concatenated once, before the first iteration,
    giving a state with ``width + in_channels`` channels (``self.in_planes``).
    """

    @staticmethod
    def initialisation(m: nn.Module):
        # Kaiming-normal (ReLU gain) weights and zero biases for every Conv2d;
        # applied to all submodules through ``self.apply`` in ``__init__``.
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight,nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)


    def zero_last_recurrent_layer(self):
        # Scale the ``conv2`` weights of the last block of the last stage by 1e-4
        # (called after initialisation; see "avoid explosion" in __init__).
        # Original note: enough for 1000 iterations, without worrying about symmetry.
        # NOTE(review): assumes ``block`` exposes a ``conv2`` attribute — confirm in .blocks
        self.recur_block[-1][-1].conv2.weight.data*=0.0001
        
        # # in principle the symmetry is broken by the following filters
        # self.recur_block[-1][-1].conv2.weight.data.zero_() 

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,**kwargs):
        # ``group_norm`` is stored but not used by the code below; ``kwargs`` are ignored.
        super().__init__()

        self.name = "FlowNet"

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=1,
                              stride=1, padding=0, bias=False)

        # conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=1,
        #                         stride=1, padding=0, bias=False)

        # Channels of the recurrent state: the projected width, plus the raw input
        # channels when recall concatenates x onto the initial state.
        self.in_planes = width

        recur_layers = []
        if recall:
            self.in_planes += in_channels

        for i in range(len(num_blocks)):
            recur_layers.append(self._make_layer(block, num_blocks[i], stride=1))

        # NOTE(review): the head (and the next iteration) consume the recurrent output
        # with ``self.in_planes`` channels, so ``block`` presumably returns that many
        # channels — confirm in .blocks.
        head_conv1 = nn.Conv2d(self.in_planes, 32, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=1,
                               stride=1, padding=0, bias=False)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)
        

        self.apply(self.initialisation)

        # avoid explosion
        self.zero_last_recurrent_layer()
        

    def _make_layer(self, block, num_blocks, stride):
        # One stage of ``num_blocks`` blocks (at least one); only the first uses ``stride``.
        # Every block maps ``self.in_planes`` inputs with ``self.width`` as its second argument.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        assert block.expansion == 1
        for strd in strides:
            layers.append(block(self.in_planes, self.width, strd))
            # self.width = self.width * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,return_recur_inter=False, **kwargs):
        """Run the recurrent module ``iters_to_do`` times.

        Args:
            x: input batch; sizes 0, 2 and 3 are read as batch, height, width.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional initial state, used as-is (no recall
                concatenation is applied to it). The projection is still computed.
            return_all_outputs: must be False in training mode (asserted).
            return_recur_inter: in training mode, also return the second value
                produced by ``self.recur_block`` on the last iteration.

        Returns:
            Training mode: ``(out, interim_thought)`` or
            ``(out, interim_thought, recur_inter)`` from the last iteration.
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, H, W)``.
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            if self.recall:
                interim_thought = torch.cat([initial_thought, x], 1)
            else:
                interim_thought = initial_thought

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3))).to(x.device)

        for i in range(iters_to_do):
            # FIXME: assumes recur_block returns (state, intermediate). With several blocks
            # in one Sequential, each block would receive the previous block's tuple.
            interim_thought, recur_inter = self.recur_block(interim_thought) ## fixme if recurrent inter

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            assert not return_all_outputs

            if return_recur_inter:
                return out, interim_thought, recur_inter
            else:
                return out, interim_thought

        return all_outputs

def flownet_recall_2d_1conv_1prop1block(width, **kwargs):
    """Build a ``FlowNet_Conv1`` with recall and block counts ``[1]`` of ``FlowBlock2DPropBegin``.

    ``kwargs`` must provide ``in_channels``; other keys are ignored.
    """
    stage_sizes = [1]
    return FlowNet_Conv1(FlowBlock2DPropBegin, stage_sizes, width=width,
                         in_channels=kwargs["in_channels"], recall=True)

def flownet_recall_2d_1conv_1prop1block_tanh(width, **kwargs):
    """Build a ``FlowNet_Conv1`` with recall and block counts ``[1]`` of ``FlowBlock2DPropBeginTanh``.

    ``kwargs`` must provide ``in_channels``; other keys are ignored.
    """
    stage_sizes = [1]
    return FlowNet_Conv1(FlowBlock2DPropBeginTanh, stage_sizes, width=width,
                         in_channels=kwargs["in_channels"], recall=True)




#####################


# ---


class Scaling(nn.Module):
    """Multiply the input by a fixed scalar ``scale`` (no learnable parameters)."""

    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale

    def extra_repr(self) -> str:
        # Uses the standard nn.Module hook instead of overriding __repr__. For this
        # child-less module, nn.Module.__repr__ renders it as "Scaling(<scale>)".
        return f"{self.scale}"

    def forward(self, x):
        return self.scale * x



class NFResidualBottleneck(nn.Module):
    """Bottleneck residual block without normalisation layers (NF-style scaling).

    Output: ``beta * residual_branch(h) + downsample(h)``. The residual branch is
    1x1 -> 3x3 (``stride``, ``dilation``, ``groups``) -> 1x1 convolutions, all with
    biases and ReLUs in between. It produces ``planes * expansion`` channels.

    Pre-activation (``Scaling(alpha)`` then ReLU; omitted when ``no_preact``):

    * without ``downsample``, it only opens the residual branch, so the skip
      path carries the raw input (``h = x``);
    * with ``downsample``, it is applied to ``x`` before both paths.
    """
    ## FIX activation functions
    expansion = 4

    def __init__(self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        alpha: float = 1.,
        beta: float = 1.,
        no_preact: bool = False,
    ):
        # alpha: input scale of the pre-activation; beta: scale of the residual branch.
        super().__init__()
        self.beta = beta

        ### pre-activations ###
        preact_layers = [] if no_preact else [
            Scaling(alpha),
            nn.ReLU(),
            # GLOBAL_ACTIVATION_MODULE([inplanes,1,1,1],2,None),
        ]
        if downsample is None:
            # identity skip: pre-activate only the residual branch
            self.preact = nn.Identity()
            residual_preact = preact_layers
        else:
            # transition block: pre-activate the shared input of both paths
            self.preact = nn.Sequential(*preact_layers)
            residual_preact = []
        ### pre-activations ###

        kernel_size = 3
        # bottleneck width, as in torchvision's ResNet bottleneck formula
        width = groups * (planes * base_width // 64)
        self.downsample = nn.Identity() if downsample is None else downsample
        self.residual_branch = nn.Sequential(
            *residual_preact,
            nn.Conv2d(inplanes, width, 1, bias=True),
            nn.ReLU(),
            # GLOBAL_ACTIVATION_MODULE([width,1,1,1],3,None),
            nn.Conv2d(width, width, kernel_size, stride, padding=dilation,
                      dilation=dilation, groups=groups, bias=True),
            nn.ReLU(),
            # GLOBAL_ACTIVATION_MODULE([width,1,1,1],4,None),
            nn.Conv2d(width, planes * self.expansion, 1, bias=True),
        )

    def forward(self, x):
        x = self.preact(x)
        skip = self.downsample(x)
        residual = self.residual_branch(x)
        return self.beta * residual + skip
    
    @torch.no_grad()
    def signal_prop(self, x, dim=(0, -1, -2)):
        """Run ``forward`` without gradients and also return signal statistics.

        Reductions are over ``dim`` (default ``(0, -1, -2)``, i.e. per channel if
        ``x`` is NCHW — assumption). Returns ``(out, (out_mu2, out_var, res_var))``:
        the mean squared reduced mean of ``out``, the mean reduced variance of
        ``out``, and the mean reduced variance of the residual branch.
        """
        # forward code
        x = self.preact(x)
        skip = self.downsample(x)
        residual = self.residual_branch(x)
        out = self.beta * residual + skip

        # compute necessary statistics
        out_mu2 = torch.mean(out.mean(dim) ** 2).item()
        out_var = torch.mean(out.var(dim)).item()
        res_var = torch.mean(residual.var(dim)).item()
        return out, (out_mu2, out_var, res_var)


class NFBasicBlock(nn.Module):
    """Residual block without normalisation layers, declared with ``expansion = 1``.

    NOTE(review): unfinished (marked FIXME). The residual branch currently
    repeats the bottleneck layout (1x1 -> 3x3 -> 1x1 convolutions, with
    ``base_width``/``groups`` width), not the two 3x3 convolutions of a
    classic basic block. Confirm the intended design before using it.

    Output: ``beta * residual_branch(h) + downsample(h)``. The pre-activation
    (``Scaling(alpha)`` then ReLU; omitted when ``no_preact``) opens the residual
    branch only when there is no ``downsample``. Otherwise it is applied before
    both paths.
    """
    ### FIXME incompleted
    ## FIX activation functions
    expansion = 1

    def __init__(self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        alpha: float = 1.,
        beta: float = 1.,
        no_preact: bool = False,
    ):
        # alpha: input scale of the pre-activation; beta: scale of the residual branch.
        super().__init__()
        self.beta = beta


        ### no preact on the first one; a regular resnet doesn't have this

        ### pre-activations ###
        preact_layers = [] if no_preact else [
            Scaling(alpha),
            nn.ReLU(),
            # GLOBAL_ACTIVATION_MODULE([inplanes,1,1,1],2,None),
        ]
        if downsample is None:
            # identity skip: pre-activate only the residual branch
            self.preact = nn.Identity()
            residual_preact = preact_layers
        else:
            # transition block: pre-activate the shared input of both paths
            self.preact = nn.Sequential(*preact_layers)
            residual_preact = []
        ### pre-activations ###

        kernel_size = 3
        width = groups * (planes * base_width // 64)
        self.downsample = nn.Identity() if downsample is None else downsample
        self.residual_branch = nn.Sequential(
            *residual_preact,
            nn.Conv2d(inplanes, width, 1, bias=True),
            nn.ReLU(),
            # GLOBAL_ACTIVATION_MODULE([width,1,1,1],3,None),
            nn.Conv2d(width, width, kernel_size, stride, padding=dilation,
                      dilation=dilation, groups=groups, bias=True),
            nn.ReLU(),
            # GLOBAL_ACTIVATION_MODULE([width,1,1,1],4,None),
            nn.Conv2d(width, planes * self.expansion, 1, bias=True),
        )

    def forward(self, x):
        x = self.preact(x)
        skip = self.downsample(x)
        residual = self.residual_branch(x)
        return self.beta * residual + skip
    
    @torch.no_grad()
    def signal_prop(self, x, dim=(0, -1, -2)):
        """Run ``forward`` without gradients and also return signal statistics.

        Returns ``(out, (out_mu2, out_var, res_var))``, reducing over ``dim``
        (per channel for NCHW input with the default — assumption).
        """
        # forward code
        x = self.preact(x)
        skip = self.downsample(x)
        residual = self.residual_branch(x)
        out = self.beta * residual + skip

        # compute necessary statistics
        out_mu2 = torch.mean(out.mean(dim) ** 2).item()
        out_var = torch.mean(out.var(dim)).item()
        res_var = torch.mean(residual.var(dim)).item()
        return out, (out_mu2, out_var, res_var)

class NFResidualNetwork(nn.Module):
    """ResNet-style classifier built from ``NFResidualBottleneck`` blocks (no normalisation).

    Layout: a 7x7/2 conv + ReLU + 3x3/2 max-pool stem for 3-channel input. Then
    four stages with 64/128/256/512 planes and strides 1/2/2/2, followed by
    ReLU -> global average pool -> flatten -> linear.

    Each block's pre-activation scale ``alpha`` is ``1 / sqrt(expected_var)``,
    where the expected signal variance is tracked analytically in
    ``_make_subnet``.

    Args:
        layers: blocks per stage; entries ``[0]`` to ``[3]`` are used.
        num_classes: output size of the final linear layer.
        beta: scale applied to every residual branch.
    """

    @staticmethod
    def initialisation(m: nn.Module):
        # Kaiming-normal weights (default arguments) and zero biases for Conv2d only;
        # the final nn.Linear keeps PyTorch's default initialisation.
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def __init__(self, layers: tuple, num_classes: int = 1000, beta: float = 1.):
        super().__init__()
        block = NFResidualBottleneck
        self._inplanes = 64
        # running analytic estimate of the signal variance entering the next block
        self._expected_var = 1.
        self.beta = beta

        self.intro = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # the stem already ends in a ReLU, so the first stage skips its pre-activation
        self.subnet1 = self._make_subnet(block, 64, layers[0], no_preact=True)
        self.subnet2 = self._make_subnet(block, 128, layers[1], stride=2)
        self.subnet3 = self._make_subnet(block, 256, layers[2], stride=2)
        self.subnet4 = self._make_subnet(block, 512, layers[3], stride=2)

        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512 * block.expansion, num_classes),
        )

        self.apply(self.initialisation)
        # self.apply(CentredWeightNormalization(dim=(1, 2, 3)))
    
    def _make_subnet(self, block, planes: int, num_layers: int, 
                     stride: int = 1, no_preact: bool = False):
        # Build one stage. The first block gets a 1x1 (possibly strided) conv on the
        # skip path whenever resolution or channel count changes.
        downsample = None
        if stride != 1 or self._inplanes != planes * block.expansion:
            downsample = nn.Conv2d(self._inplanes, planes * block.expansion, 1, stride)
        
        layers = []
        # compute expected variance analytically
        alpha = 1. / self._expected_var ** .5
        # The first block's pre-activation also feeds the skip path (when it has a
        # downsample), so the variance restarts at 1 (skip) + beta^2 (residual).
        # NOTE(review): this reset assumes the first block has a downsample.
        self._expected_var = 1. + self.beta ** 2
        layers.append(block(
            self._inplanes, planes, stride, downsample,
            alpha=alpha, beta=self.beta, no_preact=no_preact
        ))
        self._inplanes = planes * block.expansion
        for _ in range(1, num_layers):
            # track expected variance analytically: identity skip + beta^2 per residual
            alpha = 1. / self._expected_var ** .5
            self._expected_var += self.beta ** 2
            layers.append(block(
                self._inplanes, planes, alpha=alpha, beta=self.beta
            ))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.intro(x)
        x = self.subnet1(x)
        x = self.subnet2(x)
        x = self.subnet3(x)
        x = self.subnet4(x)
        return self.classifier(x)

    @torch.no_grad()
    def signal_prop(self, x, dim=(0, -1, -2)):
        """Forward pass without gradients that also collects per-block statistics.

        Returns ``(logits, (mu2_list, var_list, res_var_list))``. The first entry
        of each list describes the stem output; its residual variance is NaN.
        """
        x = self.intro(x)

        statistics = [(
            torch.mean(x.mean(dim) ** 2).item(),
            torch.mean(x.var(dim)).item(),
            float('nan'),
        )]
        for subnet in (self.subnet1, self.subnet2, self.subnet3, self.subnet4):
            for layer in subnet:
                x, stats = layer.signal_prop(x, dim)
                statistics.append(stats)
        
        # convert list of tuples to tuple of lists
        sp = tuple(map(list, zip(*statistics)))
        return self.classifier(x), sp





class NFNet(nn.Module):
    """Recurrent 2D network: projection -> repeated recurrent block -> head.

    NOTE(review): the body currently matches the plain DeepThinking layout
    (3x3 projection + ReLU, an optional recall conv on ``width + in_channels``
    channels, ``block`` stages, a 3-conv head producing 2 channels). Nothing
    NF-specific is used here yet.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False, **kwargs):
        # ``kwargs`` are ignored.
        super().__init__()

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        # built even when recall is False (it is then simply not used)
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for i in range(len(num_blocks)):
            recur_layers.append(self._make_layer(block, width, num_blocks[i], stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)




    def _make_layer(self, block, planes, num_blocks, stride):
        # One stage of at least one block; only the first block uses ``stride``.
        # Side effect: ``self.width`` becomes ``planes * block.expansion``.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations.

        With ``recall``, ``x`` is concatenated onto the state before every iteration.
        Returns ``(out, state)`` (or ``(all_outputs, out, state)`` if
        ``return_all_outputs``) in training mode, and ``all_outputs`` of shape
        ``(batch, iters_to_do, 2, H, W)`` in eval mode.
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3))).to(x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs


# def dt_net_2d(width, **kwargs):
#     return DTNet(BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=False)


def nf_net_recall_2d(width, **kwargs):
    """Build an ``NFNet`` with recall, ``BasicBlock`` and block counts ``[2]``.

    Previously this constructed ``DTNet`` (a copy-paste slip). ``NFNet`` creates
    its modules in the same order and layout, so only the model class changes.
    ``kwargs`` must provide ``in_channels``.
    """
    return NFNet(BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True)



class DTNetNoProjection(nn.Module):
    """Recurrent 2D network without an input projection.

    The recurrent state starts as zeros, and the input only enters through the
    recall concatenation. NOTE(review): with ``recall=False``, ``x`` is used
    only for its sizes/device, so the outputs do not depend on the input.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False, **kwargs):
        # ``kwargs`` are ignored.
        super().__init__()

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        # proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
        #                       stride=1, padding=1, bias=bias)

        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for i in range(len(num_blocks)):
            recur_layers.append(self._make_layer(block, width, num_blocks[i], stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        # self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)




    def _make_layer(self, block, planes, num_blocks, stride):
        # One stage of at least one block; only the first block uses ``stride``.
        # Side effect: ``self.width`` becomes ``planes * block.expansion`` (this is the
        # channel count of the zero initial state used in forward).
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations starting from a zero state.

        Returns ``(out, state)`` (or ``(all_outputs, out, state)`` if
        ``return_all_outputs``) in training mode, and ``all_outputs`` of shape
        ``(batch, iters_to_do, 2, H, W)`` in eval mode.
        """
        # initial_thought = self.projection(x)

        if interim_thought is None:
            # NOTE(review): allocated with the default dtype, not x.dtype.
            interim_thought = torch.zeros((x.size(0), self.width, x.size(2), x.size(3))).to(x.device)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3))).to(x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs


def dt_net_recall_2d_no_proj(width, **kwargs):
    """Factory: recall ``DTNet`` with ``num_blocks=[2]`` of ``BasicBlock``.

    ``kwargs["in_channels"]`` is required; all other keyword arguments are ignored.
    """
    return DTNet(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
    )




class DTNetGaussNoise(nn.Module):
    """DeepThinking 2D network whose latent thought can be perturbed by Gaussian noise.

    ``projection`` (3x3 conv + ReLU) maps the input to the initial thought.
    Every iteration optionally adds ``alpha_noise * N(0, 1)`` noise to the
    thought, concatenates the raw input when ``recall`` is set, applies
    ``recur_block`` and decodes a 2-channel map with ``head``.

    Note from the original author: use custom_beta for this.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False, **kwargs):
        super().__init__()

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Layers are created in the same order as before so seeded parameter
        # initialisation (which draws from the global RNG) is unchanged.
        proj_conv = self._conv3x3(in_channels, width, bias)
        conv_recall = self._conv3x3(width + in_channels, width, bias)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_layers = [
            self._conv3x3(width, 32, bias), nn.ReLU(),
            self._conv3x3(32, 8, bias), nn.ReLU(),
            self._conv3x3(8, 2, bias),
        ]

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(*head_layers)

    @staticmethod
    def _conv3x3(in_planes, out_planes, bias):
        """3x3 convolution with stride 1 and padding 1 (spatial size preserved)."""
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=bias)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``max(num_blocks, 1)`` blocks; only the first uses ``stride``.

        Advances ``self.width`` to ``planes * block.expansion`` after each block.
        """
        stacked = []
        for position in range(max(num_blocks, 1)):
            block_stride = stride if position == 0 else 1
            stacked.append(block(self.width, planes, block_stride,
                                 group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*stacked)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,alpha_noise=0, **kwargs):
        """Refine the thought ``iters_to_do`` times, adding noise before each step.

        Returns per-iteration outputs ``(B, iters_to_do, 2, H, W)`` in eval mode;
        in train mode ``(out, thought)`` or ``(all_outputs, out, thought)``.
        """
        initial_thought = self.projection(x)
        thought = initial_thought if interim_thought is None else interim_thought

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        for step in range(iters_to_do):
            if alpha_noise > 0:
                thought = thought + alpha_noise * torch.randn_like(thought)

            if self.recall:
                thought = torch.cat([thought, x], 1)
            thought = self.recur_block(thought)
            out = self.head(thought)
            all_outputs[:, step] = out

        if not self.training:
            return all_outputs
        if return_all_outputs:
            return all_outputs, out, thought
        return out, thought


def dt_net_recall_2d_noise(width, **kwargs):
    """Factory: ``DTNetGaussNoise`` with recall and a single stage of 2 ``BasicBlock``s.

    ``kwargs["in_channels"]`` is required; every other keyword is ignored.
    """
    return DTNetGaussNoise(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
    )


class DTNetGaussNoise1Conv(nn.Module):
    """Projection-free DeepThinking 2D network using 1x1 recall/head convs, with latent noise.

    The thought starts as zeros with ``self.width`` channels (unless
    ``interim_thought`` is given). Every iteration optionally adds
    ``alpha_noise * N(0, 1)`` noise to it, concatenates the raw input when
    ``recall`` is set, applies ``recur_block`` (1x1 recall conv followed by
    the ``block`` stages) and decodes a 2-channel map with a 1x1-conv ``head``.

    Note from the original author: use custom_beta for this.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False, **kwargs):
        super().__init__()

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # No input projection in this variant. Creation order is kept so that
        # seeded parameter initialisation is unchanged.
        conv_recall = self._conv1x1(width + in_channels, width, bias)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_layers = [
            self._conv1x1(width, 32, bias), nn.ReLU(),
            self._conv1x1(32, 8, bias), nn.ReLU(),
            self._conv1x1(8, 2, bias),
        ]

        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(*head_layers)

    @staticmethod
    def _conv1x1(in_planes, out_planes, bias):
        """Pointwise (1x1) convolution, stride 1, no padding."""
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=bias)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``max(num_blocks, 1)`` blocks; only the first uses ``stride``.

        Advances ``self.width`` to ``planes * block.expansion`` after each block.
        """
        stacked = []
        for position in range(max(num_blocks, 1)):
            block_stride = stride if position == 0 else 1
            stacked.append(block(self.width, planes, block_stride,
                                 group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*stacked)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,alpha_noise=0, **kwargs):
        """Refine a zero-initialised thought ``iters_to_do`` times with optional noise.

        Returns per-iteration outputs ``(B, iters_to_do, 2, H, W)`` in eval mode;
        in train mode ``(out, thought)`` or ``(all_outputs, out, thought)``.
        """
        thought = interim_thought
        if thought is None:
            thought = torch.zeros(x.shape[0], self.width, x.shape[2], x.shape[3], device=x.device)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        for step in range(iters_to_do):
            if alpha_noise > 0:
                thought = thought + alpha_noise * torch.randn_like(thought)

            if self.recall:
                thought = torch.cat([thought, x], 1)
            thought = self.recur_block(thought)
            out = self.head(thought)
            all_outputs[:, step] = out

        if not self.training:
            return all_outputs
        if return_all_outputs:
            return all_outputs, out, thought
        return out, thought


def dt_1conv_2d_noise_1p1b(width, **kwargs):
    """Factory: ``DTNetGaussNoise1Conv`` with recall and ``num_blocks=[2]`` of ``BasicBlock2DPropBegin``.

    Only ``kwargs["in_channels"]`` (required) is forwarded.
    """
    in_channels = kwargs["in_channels"]
    return DTNetGaussNoise1Conv(BasicBlock2DPropBegin, [2], width=width,
                                in_channels=in_channels, recall=True)


class DTNetGaussNoiseOnce(nn.Module):
    """DeepThinking 2D network that perturbs its latent thought at one random iteration.

    Same architecture as the projection-based DeepThinking net: 3x3
    ``projection`` to the initial thought, optional recall concatenation of
    the raw input, residual ``recur_block`` and a 3-layer conv ``head``
    producing a 2-channel map. When ``alpha_noise > 0``, Gaussian noise
    scaled by ``alpha_noise`` is added to the thought at exactly one
    iteration, chosen uniformly at random on every call.

    Note from the original author: use custom_beta for this.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False, **kwargs):
        super().__init__()

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack residual blocks; only the first uses ``stride``.

        Advances ``self.width`` to ``planes * block.expansion`` after each block.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,alpha_noise=0, **kwargs):
        """Refine the thought ``iters_to_do`` times, injecting noise at one random step.

        Args:
            x: network input; its batch/height/width size the output buffer.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional starting thought; defaults to ``projection(x)``.
            return_all_outputs: in train mode, also return every iteration's output.
            alpha_noise: noise scale; ``<= 0`` disables the perturbation entirely.

        Returns:
            eval mode: per-iteration outputs ``(B, iters_to_do, 2, H, W)``.
            train mode: ``(out, thought)`` or ``(all_outputs, out, thought)``.
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        # Pick the perturbed iteration with torch's RNG (seeded by
        # torch.manual_seed). The previous np.random.randint relied on an ``np``
        # name that no visible import defines (the NumPy import is commented out)
        # and raised ValueError for iters_to_do == 0. No draw is made when noise
        # is disabled.
        apply_noise_iter = -1
        if alpha_noise > 0 and iters_to_do > 0:
            apply_noise_iter = int(torch.randint(iters_to_do, (1,)).item())

        for i in range(iters_to_do):
            if i == apply_noise_iter:
                noise = torch.randn_like(interim_thought) * alpha_noise
                interim_thought = interim_thought + noise

            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def dt_net_recall_2d_noise_once(width, **kwargs):
    """Factory: ``DTNetGaussNoiseOnce`` with recall and ``num_blocks=[2]`` of ``BasicBlock``.

    ``kwargs["in_channels"]`` is required; other keywords are ignored.
    """
    return DTNetGaussNoiseOnce(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
    )



###########################
### dreamer like discrete

import torch.nn.functional as F
import torch.distributions as tdist
# import numpy as np

from functools import cached_property


class OneHotDist(tdist.OneHotCategorical):
    """One-hot categorical with a straight-through ``sample`` (Dreamer-style discrete latent).

    * ``mode``: one-hot of the arg-max logit (cached per instance).
    * ``sample``: a hard one-hot draw plus ``probs - probs.detach()``, so the
      forward value is the hard sample while gradients flow into ``probs``.
    * ``log_prob``: computed with ``F.cross_entropy`` on the logits.
    """

    def __init__(self, logits=None, probs=None, dtype=None, validate_args=False):
        # Kept for reference only; samples currently take the dtype of ``probs``.
        self._sample_dtype = dtype or torch.float32
        # Argument validation is always off (the ``validate_args`` argument is
        # ignored); the original author hit errors with it enabled.
        super().__init__(probs=probs, logits=logits, validate_args=False)

    @cached_property
    def mode(self):
        num_classes = self.event_shape[-1]
        best = self.logits.argmax(dim=-1)
        return F.one_hot(best, num_classes).float()

    def sample(self, sample_shape=(), generator=None):
        """Draw a hard one-hot sample carrying a straight-through gradient to ``probs``."""
        hard = self._sample_with_generator(sample_shape, generator)
        soft = super().probs
        # Prepend singleton dims so ``soft`` broadcasts over the sample dims.
        missing = hard.dim() - soft.dim()
        if missing > 0:
            soft = soft[(None,) * missing]
        hard += soft - soft.detach()
        return hard

    def _sample_with_generator(self, sample_shape=torch.Size(), generator=None):
        """Multinomial one-hot sampling that accepts an explicit ``torch.Generator``."""
        shape = torch.Size(sample_shape)
        categorical = self._categorical
        num_events = categorical._num_events
        source_probs = categorical.probs

        draws = torch.multinomial(source_probs.reshape(-1, num_events), shape.numel(), True,
                                  generator=generator)
        indices = draws.T.reshape(categorical._extended_shape(shape))
        return F.one_hot(indices, num_events).to(source_probs)

    def log_prob(self, value):
        """Log-probability of one-hot ``value`` via cross-entropy (numerically stable)."""
        if self._validate_args:
            self._validate_sample(value)

        value, logits = torch.broadcast_tensors(value, self.logits)
        target = value.max(dim=-1).indices

        nll = F.cross_entropy(logits.reshape(-1, *self.event_shape), target.reshape(-1).detach(),
                              reduction='none')
        return (-nll).reshape(logits.shape[:-1])



class Independent(tdist.Independent):
    """``tdist.Independent`` whose ``mode`` is taken from the wrapped distribution."""

    @property
    def mode(self):
        base = self.base_dist
        return base.mode


class Bernoulli(tdist.Bernoulli):
    """Bernoulli whose samples carry a straight-through (biased) gradient.

    ``sample`` is routed through ``rsample``: a hard ``torch.bernoulli`` draw
    plus ``probs - probs.detach()``, so the value is binary while gradients
    flow into ``probs``. ``mode`` rounds ``probs`` (threshold 0.5).
    """

    @property
    def mode(self):
        return self.probs.round()

    def rsample(self, sample_shape=torch.Size()):
        target_shape = sample_shape + self.batch_shape + self.event_shape
        probs = self.probs.expand(target_shape)
        draw = torch.bernoulli(probs)
        draw += probs - probs.detach()
        return draw

    def sample(self, sample_shape=torch.Size()):
        return self.rsample(sample_shape)

class Discrete(tdist.Bernoulli):
    """Deterministic binarisation with a straight-through gradient.

    Built on ``tdist.Bernoulli``, so with ``logits`` the probabilities are
    ``sigmoid(logits)`` (the original note says this gives a more stable
    gradient). Sampling is *not* random: ``rsample`` rounds the probabilities
    to {0, 1} and adds ``probs - probs.detach()`` so gradients reach ``probs``.
    """

    @property
    def mode(self):
        # 0.5 threshold (torch.round rounds half to even, so exactly 0.5 -> 0).
        return self.probs.round()

    def rsample(self, sample_shape=torch.Size()):
        target_shape = sample_shape + self.batch_shape + self.event_shape
        probs = self.probs.expand(target_shape)
        hard = probs.round()
        hard += probs - probs.detach()
        return hard

    def sample(self, sample_shape=torch.Size()):
        return self.rsample(sample_shape)

# class Discrete2():
# unstable even with clipping
#     def __init__(self,logits) -> None:
#         self.logits = logits
        
    

#     @property
#     def mode(self):
#         return (self.logits>=0).float()  # >0.5
    
#     def rsample(self, sample_shape=torch.Size()):
#         # Straight through biased gradient estimator.
#         # probs = self.probs.expand(sample_shape + self.batch_shape + self.event_shape)
#         sample = (self.logits>=0).float()
#         sample += (self.logits - self.logits.detach())
#         return sample
    
#     def sample(self, sample_shape=torch.Size()):
#         return self.rsample(sample_shape)




class ContinuousBernoulli(tdist.ContinuousBernoulli):
    """Continuous Bernoulli whose ``sample`` is reparameterised and ``mode`` rounds ``probs``."""

    @property
    def mode(self):
        # binary mode: 1 where probs > 0.5 (round-half-to-even at exactly 0.5)
        return self.probs.round()

    def sample(self, sample_shape=torch.Size()):
        # Use the reparameterised sampler so gradients flow through samples.
        return self.rsample(sample_shape)

## Small latent: may help, since it has fewer degrees of freedom.
class DTNetSmallTest(nn.Module):
    """DeepThinking 2D network with a small, binarised latent thought.

    There is no input projection: the thought starts as zeros with
    ``stoch_classes`` channels. Each iteration optionally concatenates the raw
    input (recall), runs ``recur_block`` (3x3 recall conv -> residual stages ->
    1x1 conv to ``stoch_classes`` logits), binarises the logits with
    ``Discrete`` (rounded ``sigmoid(logits)``; in training ``sample()`` adds a
    straight-through gradient term) and decodes the binary thought with
    ``head`` into a 2-channel map.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,stoch_classes=10, **kwargs):
        super().__init__()
        # Identifier stored on the instance (it used to be an unused local),
        # matching DTNetSmallTest1Conv.
        self.name = "DTNetSmallTest"

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        self.stoch_classes = stoch_classes

        # Layers are created in the original order so seeded init is unchanged.
        conv_recall = nn.Conv2d(stoch_classes + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        # 1x1 projection from the residual width down to the latent logits.
        conv_latent = nn.Conv2d(width, stoch_classes, kernel_size=1,
                                stride=1, padding=0, bias=bias)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        recur_layers.append(conv_latent)

        head_conv1 = nn.Conv2d(stoch_classes, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack residual blocks; only the first uses ``stride``.

        Advances ``self.width`` to ``planes * block.expansion`` after each block.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Iterate the binarised latent ``iters_to_do`` times.

        Returns per-iteration outputs ``(B, iters_to_do, 2, H, W)`` in eval mode;
        in train mode ``(out, thought)`` or ``(all_outputs, out, thought)``.
        """
        if interim_thought is None:
            interim_thought = torch.zeros(x.shape[0], self.stoch_classes, x.shape[2], x.shape[3],
                                          device=x.device)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)

            # Binarise channel-last logits. Other relaxations were tried here
            # (softmax/sigmoid squashing, Bernoulli, ContinuousBernoulli,
            # AdaptiveDiscrete); Discrete is the one in use.
            dist = Discrete(logits=interim_thought.permute(0, 2, 3, 1))

            if self.training:
                interim_thought = dist.sample().permute(0, 3, 1, 2)
            else:
                interim_thought = dist.mode.permute(0, 3, 1, 2)

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def dt_net_recall_2d_small_latent(width, **kwargs):
    """Factory: ``DTNetSmallTest`` (default 10-channel binary latent), ``num_blocks=[2]`` of ``BasicBlock2D``.

    ``kwargs["in_channels"]`` is required; other keywords are ignored.
    """
    return DTNetSmallTest(
        BasicBlock2D,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
    )


def dt_net_recall_2d_small_latent20c(width, **kwargs):
    """Factory: ``DTNetSmallTest`` with a 20-channel binary latent, ``num_blocks=[2]`` of ``BasicBlock2D``.

    ``kwargs["in_channels"]`` is required; other keywords are ignored.
    """
    return DTNetSmallTest(
        BasicBlock2D,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        stoch_classes=20,
    )


class AdaptiveDiscrete(nn.Module):
    """Snap values close to 0 or 1 onto exactly 0/1, with a learnable per-channel margin.

    With ``m = sigmoid(alpha)``, entries below ``m / 2`` become 0, entries
    above ``1 - m / 2`` become 1 and everything in between passes through.
    Snapped entries keep straight-through gradients w.r.t. ``x`` and
    ``alpha`` (zero-valued ``t - t.detach()`` terms).
    """

    def __init__(self, nr_params, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.nr_params = nr_params
        # One raw margin per channel, shape (nr_params, 1, 1); sigmoid(-3) ~= 0.047.
        self.alpha = nn.Parameter(-3 * torch.ones(nr_params, 1, 1))

    def forward(self, x):
        margin = torch.sigmoid(self.alpha)
        low = 0.5 * margin
        high = 1 - low

        below = (x < low).float()
        above = (x > high).float()
        snapped = below + above  # low < 0.5 < high, so the two masks never overlap

        straight_through = above + (x - x.detach()) - (margin - margin.detach())
        return snapped * straight_through + (1 - snapped) * x






class DTNetSmallTest1Conv(nn.Module):
    """Small binarised-latent DeepThinking 2D network built with 1x1 convolutions.

    No input projection: the thought starts as zeros with ``stoch_classes``
    channels. Each iteration optionally concatenates the raw input (recall),
    runs ``recur_block`` (1x1 recall conv -> ``block`` stages -> 1x1 conv to
    ``stoch_classes`` logits), binarises the logits with ``Discrete`` and
    decodes the binary thought with a 1x1-conv ``head`` into a 2-channel map.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,stoch_classes=10, **kwargs):
        super().__init__()
        self.name = "DTNetSmallTest1Conv"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        self.stoch_classes = stoch_classes

        # Creation order matches the original so seeded init is unchanged.
        conv_recall = self._conv1x1(stoch_classes + in_channels, width, bias)
        conv_latent = self._conv1x1(width, stoch_classes, bias)

        recur_layers = [conv_recall] if recall else []
        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))
        recur_layers.append(conv_latent)

        head_layers = [
            self._conv1x1(stoch_classes, 32, bias), nn.ReLU(),
            self._conv1x1(32, 8, bias), nn.ReLU(),
            self._conv1x1(8, 2, bias),
        ]

        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(*head_layers)

        # Registered (and saved in the state_dict) but not used by forward();
        # it was one of the latent relaxations that were tried.
        self.adaptive_discrete = AdaptiveDiscrete(stoch_classes)

    @staticmethod
    def _conv1x1(in_planes, out_planes, bias):
        """Pointwise (1x1) convolution, stride 1, no padding."""
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=bias)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``max(num_blocks, 1)`` blocks; only the first uses ``stride``.

        Advances ``self.width`` to ``planes * block.expansion`` after each block.
        """
        stacked = []
        for position in range(max(num_blocks, 1)):
            block_stride = stride if position == 0 else 1
            stacked.append(block(self.width, planes, block_stride,
                                 group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*stacked)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Iterate the binarised latent ``iters_to_do`` times.

        Returns per-iteration outputs ``(B, iters_to_do, 2, H, W)`` in eval mode;
        in train mode ``(out, thought)`` or ``(all_outputs, out, thought)``.
        """
        thought = interim_thought
        if thought is None:
            thought = torch.zeros(x.shape[0], self.stoch_classes, x.shape[2], x.shape[3], device=x.device)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        for step in range(iters_to_do):
            if self.recall:
                thought = torch.cat([thought, x], 1)
            logits = self.recur_block(thought).permute(0, 2, 3, 1)

            # Other relaxations were tried here (softmax/tanh/relu/sigmoid/
            # hardsigmoid squashing, OneHotDist, Bernoulli, ContinuousBernoulli,
            # AdaptiveDiscrete, mapping to [-1, 1]); Discrete is the one in use.
            dist = Discrete(logits=logits)
            binary = dist.sample() if self.training else dist.mode
            thought = binary.permute(0, 3, 1, 2)

            out = self.head(thought)
            all_outputs[:, step] = out

        if not self.training:
            return all_outputs
        if return_all_outputs:
            return all_outputs, out, thought
        return out, thought


def dt_small_latent_1conv_1prop1block(width, **kwargs):
    """Factory: ``DTNetSmallTest1Conv`` (default 10-channel latent), one ``BasicBlock2DPropBegin``.

    ``kwargs["in_channels"]`` is required; other keywords are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNetSmallTest1Conv(BasicBlock2DPropBegin, [1], width=width,
                               in_channels=in_channels, recall=True)

def dt_small_latent_1conv_1prop1block_20c(width, **kwargs):
    """Factory: ``DTNetSmallTest1Conv`` with a 20-channel latent and one ``BasicBlock2DPropBegin``."""
    in_channels = kwargs["in_channels"]
    return DTNetSmallTest1Conv(BasicBlock2DPropBegin, [1], width=width,
                               in_channels=in_channels, recall=True, stoch_classes=20)


def dt_small_latent_1conv_1prop1block_3c(width, **kwargs):
    """Factory: ``DTNetSmallTest1Conv`` with a 3-channel latent and one ``BasicBlock2DPropBegin``.

    Previously this built ``stoch_classes=2`` (identical to the ``_2c``
    variant) despite its name. NOTE: checkpoints saved under this name before
    the fix have a 2-channel latent and will not load into this model.
    """
    return DTNetSmallTest1Conv(BasicBlock2DPropBegin, [1], width=width, in_channels=kwargs["in_channels"], recall=True, stoch_classes=3)

def dt_small_latent_1conv_1prop1block_2c(width, **kwargs):
    """Factory: ``DTNetSmallTest1Conv`` with a 2-channel latent and one ``BasicBlock2DPropBegin``."""
    in_channels = kwargs["in_channels"]
    return DTNetSmallTest1Conv(BasicBlock2DPropBegin, [1], width=width,
                               in_channels=in_channels, recall=True, stoch_classes=2)


def dt_small_latent_1conv_1prop1block_1c(width, **kwargs):
    """Factory: ``DTNetSmallTest1Conv`` with a single-channel latent and one ``BasicBlock2DPropBegin``."""
    in_channels = kwargs["in_channels"]
    return DTNetSmallTest1Conv(BasicBlock2DPropBegin, [1], width=width,
                               in_channels=in_channels, recall=True, stoch_classes=1)


class DTNet_Conv1AddNoise(nn.Module):
    """Projection-free DeepThinking 2D network (1x1 convs) with a noise-consistency objective.

    ``forward`` is the plain recurrent pass. ``forward_L1_interim`` additionally
    runs a noisy copy of each step (Gaussian noise scaled by ``custom_beta``
    added to the thought only, never to ``x``) and accumulates the mean L1
    distance between the noisy and clean next thoughts.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, **kwargs):
        super().__init__()

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=1,
                                stride=1, padding=0, bias=False)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=1,
                               stride=1, padding=0, bias=False)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=1,
                               stride=1, padding=0, bias=False)

        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        assert self.recall, "Must have recall for this model"

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack residual blocks; only the first uses ``stride`` (``bias`` is left at the block default)."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Plain recurrent pass from an all-zero thought.

        Returns per-iteration outputs ``(B, iters_to_do, 2, H, W)`` in eval mode;
        in train mode ``(out, thought)`` or ``(all_outputs, out, thought)``.
        """
        if interim_thought is None:
            interim_thought = torch.zeros(x.shape[0], self.width, x.shape[2], x.shape[3], device=x.device)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

    def forward_L1_interim(self, x, iters_to_do, interim_thought=None, return_all_outputs=False,only_final_and_next=False,custom_beta=0, **kwargs):
        """Recurrent pass that also measures sensitivity of each step to thought noise.

        At every iteration, the clean thought and a copy perturbed by
        ``custom_beta * N(0, 1)`` both go through ``recur_block``;
        ``noise_loss`` accumulates ``mean(|noisy_next - clean_next.detach()|)``.
        ``only_final_and_next`` is accepted for interface compatibility but unused.

        Returns:
            eval mode: per-iteration outputs ``(B, iters_to_do, 2, H, W)``.
            train mode: ``(out, thought, noise_loss)`` or
            ``(all_outputs, out, thought, noise_loss)``.
        """
        if interim_thought is None:
            interim_thought = torch.zeros(x.shape[0], self.width, x.shape[2], x.shape[3], device=x.device)

        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        noise_loss = 0

        for i in range(iters_to_do):
            # Noise only the thought; x stays clean. Built outside the recall
            # branch so the noisy input exists even if recall is disabled (the
            # constructor assert is stripped under ``python -O``).
            interim_thought_noise_inp = interim_thought + torch.randn_like(interim_thought) * custom_beta
            if self.recall:
                interim_thought_noise_inp = torch.cat([interim_thought_noise_inp, x], 1)
                interim_thought = torch.cat([interim_thought, x], 1)

            interim_thought_noise = self.recur_block(interim_thought_noise_inp)
            interim_thought = self.recur_block(interim_thought)

            # L1 consistency: pull the noisy next thought towards the clean one.
            noise_loss += torch.mean(torch.abs(interim_thought_noise - interim_thought.detach()))

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought, noise_loss
            return out, interim_thought, noise_loss

        return all_outputs


def dt_1conv_1p1block_addnoise(width, **kwargs):
    """Factory: ``DTNet_Conv1AddNoise`` with recall and one ``BasicBlock2DPropBegin``.

    ``kwargs["in_channels"]`` is required; other keywords are ignored.
    """
    in_channels = kwargs["in_channels"]
    return DTNet_Conv1AddNoise(BasicBlock2DPropBegin, [1], width=width,
                               in_channels=in_channels, recall=True)
###

# new implementation
import math
from typing import Iterable

import torch as th
import torch.nn as nn
from torch.nn import Parameter


class LayerNorm(nn.Module):
    """
    Layer Normalization based on Ba & al.:
    'Layer Normalization'
    https://arxiv.org/pdf/1607.06450.pdf

    Every sample (dim 0) is flattened and normalised with its own mean and
    unbiased variance. When ``learnable``, a per-feature gain ``alpha`` and
    shift ``beta`` of shape ``(1, input_size)`` are applied; both are
    initialised uniformly in ``[-1/sqrt(input_size), 1/sqrt(input_size)]``.
    """

    def __init__(self, input_size: int, learnable: bool = True, epsilon: float = 1e-6):
        super(LayerNorm, self).__init__()
        self.input_size = input_size
        self.learnable = learnable
        gain = th.zeros(1, input_size)
        shift = th.zeros(1, input_size)
        # Plain (unregistered) zero tensors when not learnable; unused in that case.
        self.alpha = Parameter(gain) if learnable else gain
        self.beta = Parameter(shift) if learnable else shift
        self.epsilon = epsilon
        self.reset_parameters()

    def reset_parameters(self):
        bound = 1.0 / math.sqrt(self.input_size)
        with th.no_grad():
            for param in self.parameters():
                param.uniform_(-bound, bound)

    def forward(self, x: th.Tensor) -> th.Tensor:
        original_shape = x.size()
        flat = x.view(x.size(0), -1)
        mean = flat.mean(1, keepdim=True)
        std = th.sqrt(flat.var(1, keepdim=True) + self.epsilon)
        normed = (flat - mean) / std
        if self.learnable:
            normed = self.alpha.expand_as(normed) * normed + self.beta.expand_as(normed)
        return normed.view(original_shape)



class LSTM(nn.Module):

    """
    An implementation of Hochreiter & Schmidhuber:
    'Long-Short Term Memory'
    http://www.bioinf.jku.at/publications/older/2604.pdf

    Single-step cell: each ``forward`` call consumes one time step and
    returns ``(h_t, (h_t, c_t))`` with tensors shaped ``(1, batch, hidden)``.

    Special args:

    dropout_method: one of (case-insensitive)
            * pytorch: default dropout implementation (on the hidden output)
            * gal: uses GalLSTM's dropout (per-unit ``self.mask`` on the hidden output)
            * moon: uses MoonLSTM's dropout (per-unit ``self.mask`` on the cell state)
            * semeniuta: uses SemeniutaLSTM's dropout (dropout on the candidate g_t)

    For ``gal`` and ``moon``, call ``sample_mask`` to draw ``self.mask`` before
    calling ``forward`` in training mode with ``dropout > 0``.
    """

    def __init__(
        self, input_size: int, hidden_size: int, bias: bool = True, dropout: float = 0.0, dropout_method: str = "pytorch"
    ):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.dropout = dropout
        self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()
        # Normalise once: validation used to be case-insensitive while forward()
        # compares against lower-case names, so e.g. "Gal" silently disabled dropout.
        dropout_method = dropout_method.lower()
        assert dropout_method in ["pytorch", "gal", "moon", "semeniuta"]
        self.dropout_method = dropout_method

    def sample_mask(self, device='cpu'):
        """Draw a fresh per-unit keep-mask of shape ``(1, hidden_size)`` for gal/moon dropout."""
        keep = 1.0 - self.dropout
        self.mask = th.bernoulli(th.empty(1, self.hidden_size).fill_(keep)).to(device)

    def reset_parameters(self):
        """Initialise every parameter uniformly in ``[-1/sqrt(hidden), 1/sqrt(hidden)]``."""
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x: th.Tensor, hidden: Tuple[th.Tensor, th.Tensor]) -> Tuple[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """Advance the cell by one step.

        ``x``, ``h`` and ``c`` are flattened with ``.view(t.size(1), -1)``, so
        they are presumably shaped ``(1, batch, features)`` (TODO confirm with callers).
        """
        do_dropout = self.training and self.dropout > 0.0
        h, c = hidden
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x = x.view(x.size(1), -1)

        # Linear mappings
        preact = self.i2h(x) + self.h2h(h)

        # activations: the first 3*hidden columns are the sigmoid gates (i, f, o),
        # the last hidden columns are the tanh candidate g.
        gates = preact[:, : 3 * self.hidden_size].sigmoid()
        g_t = preact[:, 3 * self.hidden_size :].tanh()
        i_t = gates[:, : self.hidden_size]
        f_t = gates[:, self.hidden_size : 2 * self.hidden_size]
        o_t = gates[:, -self.hidden_size :]

        # cell computations
        if do_dropout and self.dropout_method == "semeniuta":
            g_t = F.dropout(g_t, p=self.dropout, training=self.training)

        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)

        if do_dropout and self.dropout_method == "moon":
            # Fixed mask on the cell state, rescaled by 1/keep. Applied as a
            # tracked op: the former ``c_t.data.set_(...)`` raises a RuntimeError
            # on current PyTorch (storage change through ``.data``).
            c_t = th.mul(c_t, self.mask) * (1.0 / (1.0 - self.dropout))

        h_t = th.mul(o_t, c_t.tanh())

        # Reshape for compatibility
        if do_dropout:
            if self.dropout_method == "pytorch":
                F.dropout(h_t, p=self.dropout, training=self.training, inplace=True)
            if self.dropout_method == "gal":
                # Fixed mask on the hidden output (same reasoning as the moon branch).
                h_t = th.mul(h_t, self.mask) * (1.0 / (1.0 - self.dropout))

        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)
    
class LayerNormLSTM(LSTM):

    """
    Layer Normalization LSTM, based on Ba et al.:
    'Layer Normalization'
    https://arxiv.org/pdf/1607.06450.pdf

    Special args:
        ln_preact: whether to Layer Normalize the i2h and h2h pre-activations
            (each normalised separately before they are summed).
        learnable: whether the LN gain & bias are learnable.

    The new cell state is always layer-normalised (``ln_cell``) before the
    output ``h_t = o_t * tanh(c_t)`` is computed.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        dropout: float = 0.0,
        dropout_method: str = "pytorch",
        ln_preact: bool = True,
        learnable: bool = True,
    ):
        super(LayerNormLSTM, self).__init__(
            input_size=input_size, hidden_size=hidden_size, bias=bias, dropout=dropout, dropout_method=dropout_method
        )
        if ln_preact:
            self.ln_i2h = LayerNorm(4 * hidden_size, learnable=learnable)
            self.ln_h2h = LayerNorm(4 * hidden_size, learnable=learnable)
        self.ln_preact = ln_preact
        self.ln_cell = LayerNorm(hidden_size, learnable=learnable)

    def forward(self, x: th.Tensor, hidden: Tuple[th.Tensor, th.Tensor]) -> Tuple[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """Run a single layer-normalised LSTM step with optional dropout.

        Args:
            x: input whose dimension 1 is the batch; flattened to
                ``(x.size(1), -1)`` (presumably ``(1, batch, input_size)``).
            hidden: ``(h, c)``, each flattened to ``(size(1), -1)``.

        Returns:
            ``(h_t, (h_t, c_t))`` with ``h_t``/``c_t`` of shape
            ``(1, batch, hidden_size)``.
        """
        do_dropout = self.training and self.dropout > 0.0
        h, c = hidden
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x = x.view(x.size(1), -1)

        # Linear mappings, optionally normalised before being summed.
        i2h = self.i2h(x)
        h2h = self.h2h(h)
        if self.ln_preact:
            i2h = self.ln_i2h(i2h)
            h2h = self.ln_h2h(h2h)
        preact = i2h + h2h

        # Column layout of preact: [input | forget | output | candidate].
        gates = preact[:, : 3 * self.hidden_size].sigmoid()
        g_t = preact[:, 3 * self.hidden_size :].tanh()
        i_t = gates[:, : self.hidden_size]
        f_t = gates[:, self.hidden_size : 2 * self.hidden_size]
        o_t = gates[:, -self.hidden_size :]

        # cell computations
        if do_dropout and self.dropout_method == "semeniuta":
            g_t = F.dropout(g_t, p=self.dropout, training=self.training)

        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)

        if do_dropout and self.dropout_method == "moon":
            # Differentiable masking. The former ``c_t.data.set_(...)`` is
            # rejected by current PyTorch and, where it ran, hid the mask
            # from autograd.
            c_t = th.mul(c_t, self.mask) * (1.0 / (1.0 - self.dropout))

        c_t = self.ln_cell(c_t)
        h_t = th.mul(o_t, c_t.tanh())

        if do_dropout:
            if self.dropout_method == "pytorch":
                h_t = F.dropout(h_t, p=self.dropout, training=self.training)
            if self.dropout_method == "gal":
                h_t = th.mul(h_t, self.mask) * (1.0 / (1.0 - self.dropout))

        # Reshape back to (1, batch, hidden_size) for compatibility.
        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)


class LayerNormGalLSTM(LayerNormLSTM):

    """
    Layer-normalised LSTM that always uses Gal-style dropout: a fixed
    per-unit mask (drawn by ``sample_mask``) applied to the hidden output.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Whatever dropout_method the caller passed is overridden here.
        self.dropout_method = "gal"
        # Draw an initial mask so forward() can run straight away.
        self.sample_mask()


from typing import List, Tuple, Type

class MultiLayerLSTM(nn.Module):

    """
    Stack of recurrent layers of any ``LSTM``-compatible type.

    The output of each layer is the input of the next. Dropout is deactivated
    on the last layer: it is always constructed with ``dropout=0.0``.

    Args:
        input_size: feature size of the first layer's input.
        layer_type: layer class, called as
            ``layer_type(in_size, hidden_size, *args, **kwargs)``.
        layer_sizes: hidden size of each layer, first to last.
        *args, **kwargs: forwarded to every layer constructor; positional
            ``args`` bind to the parameters after ``hidden_size``.
    """

    def __init__(self, input_size: int, layer_type: Type["LSTM"], layer_sizes: List[int] = (64, 64), *args, **kwargs):
        super(MultiLayerLSTM, self).__init__()
        rnn = layer_type
        self.layers: List["LSTM"] = []
        prev_size = input_size
        for size in layer_sizes[:-1]:
            # Sizes are passed positionally so that ``*args`` lines up with the
            # parameters after ``hidden_size``; with keyword sizes any
            # non-empty ``*args`` collided with ``input_size`` (TypeError).
            layer = rnn(prev_size, size, *args, **kwargs)
            self.layers.append(layer)
            prev_size = size
        if len(layer_sizes) > 0:
            # Drop any caller-supplied dropout: the last layer always gets 0.0.
            last_kwargs = {key: value for key, value in kwargs.items() if key != "dropout"}
            layer = rnn(prev_size, layer_sizes[-1], *args, dropout=0.0, **last_kwargs)
            self.layers.append(layer)
        self.layer_sizes = layer_sizes
        self.input_size = input_size
        # The plain list above is not tracked by nn.Module; this registers the layers.
        self.params = nn.ModuleList(self.layers)

    def reset_parameters(self):
        """Call ``reset_parameters`` on every layer."""
        for layer in self.layers:
            layer.reset_parameters()

    def create_hiddens(self, batch_size: int = 1, device=None) -> List[Tuple[th.Tensor, th.Tensor]]:
        """Random initial ``(h, c)`` for each layer, each ``(1, batch_size, hidden_size)``.

        Values are drawn from ``N(0, std)`` with the Xavier-style
        ``std = sqrt(2 / (input_size + hidden_size))`` of each layer. Sampling
        happens on the CPU; pass ``device`` to move the states afterwards.
        """
        hiddens: List[Tuple[th.Tensor, th.Tensor]] = []
        for layer in self.layers:
            std = math.sqrt(2.0 / (layer.input_size + layer.hidden_size))
            h = th.empty(1, batch_size, layer.hidden_size).normal_(0, std)
            c = th.empty(1, batch_size, layer.hidden_size).normal_(0, std)
            if device is not None:
                h, c = h.to(device), c.to(device)
            hiddens.append((h, c))
        return hiddens

    def sample_mask(self, device=None):
        """Resample every layer's dropout mask, optionally on ``device``."""
        for layer in self.layers:
            if device is None:
                layer.sample_mask()
            else:
                layer.sample_mask(device)

    def forward(
        self, x: th.Tensor, hiddens: List[Tuple[th.Tensor, th.Tensor]]
    ) -> Tuple[th.Tensor, List[Tuple[th.Tensor, th.Tensor]]]:
        """Run ``x`` through every layer, pairing layer ``i`` with ``hiddens[i]``.

        Returns:
            The last layer's output and the list of new per-layer states.
        """
        new_hiddens: List[Tuple[th.Tensor, th.Tensor]] = []
        for layer, h in zip(self.layers, hiddens):
            x, new_h = layer(x, h)
            new_hiddens.append(new_h)
        return x, new_hiddens


def create_hidden(batch_size,input_size, hidden_size,device, zeros=False) -> List[Tuple[th.Tensor, th.Tensor]]:
    # Uses Xavier init here.

    if zeros:
        return (
            th.zeros(1, batch_size, hidden_size).to(device),
            th.zeros(1, batch_size, hidden_size).to(device),
        )


    std = math.sqrt(2.0 / (input_size + hidden_size))
    return (
            th.empty(1, batch_size, hidden_size).normal_(0, std).to(device),
            th.empty(1, batch_size, hidden_size).normal_(0, std).to(device),
        )

####

class LSTMCellImproved(nn.LSTMCell):
    """``nn.LSTMCell`` with an opt-in custom initialisation (``_init_weights``).

    The constructor does not call ``_init_weights``, so by default the cell
    keeps ``nn.LSTMCell``'s own initialisation and ``unit_forget_bias`` only
    takes effect once ``_init_weights`` is invoked explicitly.

    Args:
        *args, **kwargs: forwarded to ``nn.LSTMCell``.
        unit_forget_bias: when ``_init_weights`` runs, set the forget-gate
            slice of each bias vector to 1.
    """

    def __init__(self, *args, unit_forget_bias=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.unit_forget_bias = unit_forget_bias
        # Custom init is currently disabled; call self._init_weights() to apply it.

    def _init_weights(self):
        """
        Orthogonal init for recurrent weights, Xavier-uniform for input weights.

        Biases are always zeroed. When ``unit_forget_bias`` is set, the slice
        ``[hidden_size:2*hidden_size]`` (the forget gate in PyTorch's
        documented i, f, g, o gate order) is then set to 1. Previously, with
        ``unit_forget_bias=False`` the biases kept their random default init,
        contradicting "bias is 0 except for the forget gate".
        """
        for name, param in self.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param.data)
            elif "weight_ih" in name:
                nn.init.xavier_uniform_(param.data)
            elif "bias" in name:
                nn.init.zeros_(param.data)
                if self.unit_forget_bias:
                    param.data[self.hidden_size:2 * self.hidden_size] = 1


class SampleDrop(nn.Module):
    """Applies dropout to 2-D inputs with a mask that stays fixed until resampled.

    ``set_weights`` draws a mask; every later ``forward`` in training mode
    multiplies by that same mask. The mask is ``F.dropout`` applied to a
    tensor of ones, so kept entries are already scaled by ``1 / (1 - dropout)``
    (same as GAL dropout).

    Args:
        dropout: drop probability in ``[0, 1)``.
    """
    def __init__(self, dropout=0):
        super().__init__()

        assert 0 <= dropout < 1
        self._mask = None
        self._dropout = dropout

    def set_weights(self, X):
        """Calculate a new dropout mask shaped like the 2-D tensor ``X``.

        The mask is created with ``X.new_ones`` so it shares ``X``'s device and
        dtype. The previous CPU tensor + ``.cuda()`` approach broke on non-CUDA
        accelerators and on CUDA devices other than the current one.
        """
        assert len(X.shape) == 2

        mask = X.new_ones(X.size(0), X.size(1))
        self._mask = F.dropout(mask, p=self._dropout, training=self.training)

    def forward(self, X):
        """Return ``X`` in eval mode or with zero dropout, else ``X * mask``."""
        if not self.training or not self._dropout:
            return X
        if self._mask is None:
            raise RuntimeError("SampleDrop.set_weights must be called before forward in training mode")
        return X * self._mask
        



class SampleDropND(nn.Module):
    """Applies dropout to inputs of any rank with a mask fixed until resampled.

    Like ``SampleDrop`` but without a rank restriction: the mask has exactly
    the shape of the tensor passed to ``set_weights``. Kept entries are
    scaled by ``1 / (1 - dropout)`` (same as GAL dropout).

    Args:
        dropout: drop probability in ``[0, 1)``.
    """
    def __init__(self, dropout=0):
        super().__init__()

        assert 0 <= dropout < 1
        self._mask = None
        self._dropout = dropout

    def set_weights(self, X):
        """Calculate a new dropout mask with ``X``'s shape, device and dtype.

        ``torch.ones_like`` already places the mask on ``X``'s device; the old
        extra ``.cuda()`` call could move it to the *current* CUDA device
        instead of ``X``'s.
        """
        mask = torch.ones_like(X)
        self._mask = F.dropout(mask, p=self._dropout, training=self.training)

    def forward(self, X):
        """Return ``X`` in eval mode or with zero dropout, else ``X * mask``."""
        if not self.training or not self._dropout:
            return X
        if self._mask is None:
            raise RuntimeError("SampleDropND.set_weights must be called before forward in training mode")
        return X * self._mask

class SampleDrop2D(nn.Module):
    """Applies dropout to 4-D inputs with a mask that stays fixed until resampled.

    The mask has the full shape of the tensor given to ``set_weights`` (one
    entry per element, not per channel). Kept entries are scaled by
    ``1 / (1 - dropout)`` (same as GAL dropout).

    Args:
        dropout: drop probability in ``[0, 1)``.
    """
    def __init__(self, dropout=0):
        super().__init__()

        assert 0 <= dropout < 1
        self._mask = None
        self._dropout = dropout

    def set_weights(self, X):
        """Calculate a new dropout mask shaped like the 4-D tensor ``X``.

        ``X.new_ones`` keeps the mask on ``X``'s device and in ``X``'s dtype
        (the former CPU tensor + ``.cuda()`` only handled the current CUDA
        device and always produced float32).
        """
        assert len(X.shape) == 4

        mask = X.new_ones(X.size(0), X.size(1), X.size(2), X.size(3))
        self._mask = F.dropout(mask, p=self._dropout, training=self.training)

    def forward(self, X):
        """Return ``X`` in eval mode or with zero dropout, else ``X * mask``."""
        if not self.training or not self._dropout:
            return X
        if self._mask is None:
            raise RuntimeError("SampleDrop2D.set_weights must be called before forward in training mode")
        return X * self._mask

from einops import rearrange

class DTNetLSTM(nn.Module):
    """DeepThinking 2D network whose recurrent state passes through an LSTM cell.

    Feature maps ``(b, c, h, w)`` are flattened to ``(b*h*w, c)`` so every
    pixel is an independent sample for a shared ``LSTMCellImproved``; the
    LSTM hidden output is reshaped back into a feature map and fed to the
    recurrent block on the next iteration. Fixed-mask dropout
    (``SampleDrop``) is applied to the LSTM input and hidden output, with
    masks redrawn once per ``forward`` call.

    Args:
        block: block class for the recurrent body, called as
            ``block(in_planes, planes, stride, group_norm=..., bias=...)``.
        num_blocks: number of blocks in each recurrent stage.
        width: number of feature channels.
        in_channels: channels of the input ``x``.
        recall: concatenate ``x`` to the state before each recurrent step.
        group_norm: forwarded to ``block``.
        bias: whether the convolutions use a bias term.
        _dropout_i: dropout probability on the LSTM input.
        _dropout_h: dropout probability on the LSTM hidden output.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,
                  _dropout_i=0,_dropout_h=0,**kwargs):
        super().__init__()

        self.name = "DTNetLSTM"

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        # Created unconditionally, but only part of the model when recall=True:
        # it maps [state, x] (concatenated on channels) back to `width` channels.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        # Head: width -> 32 -> 8 -> 2 output channels per pixel.
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.lstm = LSTMCellImproved(width, width)

        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        self._input_drop = SampleDrop(dropout=self._dropout_i)
        self._state_drop = SampleDrop(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; updates ``self.width`` to their output width.

        NOTE(review): ``[stride] + [1] * (num_blocks - 1)`` always has at least
        one entry, so ``num_blocks=0`` still builds one block.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch ``(b, in_channels, h, w)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional ``(b, width, h, w)`` state to start from
                instead of ``projection(x)``.
            return_all_outputs: in training mode, also return every
                iteration's output.

        Returns:
            Training mode: ``(out, interim_thought)``, or
            ``(all_outputs, out, interim_thought)`` with ``return_all_outputs``.
            Eval mode: ``all_outputs`` of shape ``(b, iters_to_do, 2, h, w)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``, since there
                is then no last output to return.
        """
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be >= 1 in training mode")

        batch, rows, cols = x.size(0), x.size(2), x.size(3)

        # The projection is only needed when no starting state is supplied.
        if interim_thought is None:
            interim_thought = self.projection(x)

        all_outputs = torch.zeros((batch, iters_to_do, 2, rows, cols), device=x.device)

        # Every pixel becomes one LSTM sample: (b, c, h, w) -> (b*h*w, c).
        thought_flat = rearrange(interim_thought, 'b c h w -> (b h w) c')

        # Draw the dropout masks once; all iterations of this call share them.
        self._input_drop.set_weights(thought_flat)
        self._state_drop.set_weights(thought_flat)

        # First step: no state passed, so nn.LSTMCell starts from zeros.
        thought_h, c = self.lstm(self._input_drop(thought_flat))
        thought_h = self._state_drop(thought_h)
        interim_thought = rearrange(thought_h, '(b h w) c -> b c h w', b=batch, h=rows, w=cols)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought_new = torch.cat([interim_thought, x], 1)
            else:
                interim_thought_new = interim_thought
            interim_thought_new = self.recur_block(interim_thought_new)

            interim_thought_new = rearrange(interim_thought_new, 'b c h w -> (b h w) c')
            thought_h, c = self.lstm(self._input_drop(interim_thought_new), (thought_h, c))
            thought_h = self._state_drop(thought_h)
            interim_thought = rearrange(thought_h, '(b h w) c -> b c h w', b=batch, h=rows, w=cols)

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs

def _build_dt_net_recall_lstm(width, in_channels, num_blocks, dropout=0):
    """Build a recall ``DTNetLSTM`` on ``BasicBlock`` with equal input/hidden dropout.

    NOTE(review): ``DTNetLSTM._make_layer`` builds ``[stride] + [1] * (n - 1)``
    blocks, which is one block when ``n == 0``; the ``_0b`` variant therefore
    has the same architecture as ``_1b``.
    """
    return DTNetLSTM(BasicBlock, [num_blocks], width=width, in_channels=in_channels, recall=True,
                     _dropout_i=dropout, _dropout_h=dropout)


def dt_net_recall_lstm_2d(width, **kwargs):
    return _build_dt_net_recall_lstm(width, kwargs["in_channels"], 2)


def dt_net_recall_lstm_drop02_2d(width, **kwargs):
    return _build_dt_net_recall_lstm(width, kwargs["in_channels"], 2, dropout=0.2)


def dt_net_recall_lstm_drop02_1b_2d(width, **kwargs):
    return _build_dt_net_recall_lstm(width, kwargs["in_channels"], 1, dropout=0.2)


def dt_net_recall_lstm_drop02_0b_2d(width, **kwargs):
    return _build_dt_net_recall_lstm(width, kwargs["in_channels"], 0, dropout=0.2)


def dt_net_recall_lstm_drop05_2d(width, **kwargs):
    return _build_dt_net_recall_lstm(width, kwargs["in_channels"], 2, dropout=0.5)


def dt_net_recall_lstm_drop01_2d(width, **kwargs):
    return _build_dt_net_recall_lstm(width, kwargs["in_channels"], 2, dropout=0.1)




class DTNetLSTM2(nn.Module):
    """DeepThinking 2D network whose recurrent state passes through a ``LayerNormLSTM``.

    Feature maps ``(b, c, h, w)`` are flattened to ``(1, b*h*w, c)`` so that
    every pixel is one LSTM batch entry of a length-1 sequence. The LSTM's
    hidden output is reshaped back and fed to the recurrent block next
    iteration. Dropout is handled inside the LSTM (``dropout_method``); its
    mask is redrawn once per ``forward`` call.

    Args:
        block: block class for the recurrent body, called as
            ``block(in_planes, planes, stride, group_norm=..., bias=...)``.
        num_blocks: number of blocks in each recurrent stage.
        width: number of feature channels.
        in_channels: channels of the input ``x``.
        recall: concatenate ``x`` to the state before each recurrent step.
        group_norm: forwarded to ``block``.
        bias: whether the convolutions use a bias term.
        _dropout: LSTM dropout probability.
        init_h_zeros: start the LSTM from zero states instead of random ones.
        ln_preact: layer-normalise the LSTM pre-activations.
        dropout_method: LSTM dropout variant.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,
                  _dropout=0,init_h_zeros=False,ln_preact=True,dropout_method='gal',**kwargs):
        super().__init__()

        self.name = "DTNetLSTM2"

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        # Created unconditionally, but only part of the model when recall=True.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for stage_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, stage_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        # NOTE(review): 'input' passes this check, but LayerNormLSTM's base
        # constructor only accepts pytorch/gal/moon/semeniuta and will reject it.
        assert dropout_method in ['pytorch','gal','moon','semeniuta','input']

        self.lstm = LayerNormLSTM(width, width,dropout=_dropout,ln_preact=ln_preact,dropout_method=dropout_method)

        self._dropout = _dropout
        self.init_h_zeros=init_h_zeros

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; updates ``self.width`` to their output width.

        NOTE(review): ``[stride] + [1] * (num_blocks - 1)`` always has at least
        one entry, so ``num_blocks=0`` still builds one block.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations on ``x``.

        Args:
            x: input batch ``(b, in_channels, h, w)``.
            iters_to_do: number of recurrent iterations.
            interim_thought: optional ``(b, width, h, w)`` state to start from
                instead of ``projection(x)``.
            return_all_outputs: in training mode, also return every
                iteration's output.

        Returns:
            Training mode: ``(out, interim_thought)``, or
            ``(all_outputs, out, interim_thought)`` with ``return_all_outputs``.
            Eval mode: ``all_outputs`` of shape ``(b, iters_to_do, 2, h, w)``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1``.
        """
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be >= 1 in training mode")

        batch, rows, cols = x.size(0), x.size(2), x.size(3)

        # The projection is only needed when no starting state is supplied.
        if interim_thought is None:
            interim_thought = self.projection(x)

        all_outputs = torch.zeros((batch, iters_to_do, 2, rows, cols), device=x.device)

        # (b, c, h, w) -> (1, b*h*w, c): LayerNormLSTM.forward treats
        # dimension 1 as the batch, so every pixel is one batch entry.
        thought_flat = rearrange(interim_thought, 'b c h w -> (b h w) c')[None,...]

        # New dropout mask for this call (read by the "gal"/"moon" methods).
        self.lstm.sample_mask(thought_flat.device)

        # The LSTM batch is dimension 1 (b*h*w). Passing size(0), the
        # singleton axis, used to create a single (1, 1, width) state that
        # was broadcast to every pixel instead of one state per pixel.
        h, c = create_hidden(thought_flat.size(1), self.width, self.width,
                             device=thought_flat.device, zeros=self.init_h_zeros)

        _, (interim_thought_h, c) = self.lstm(thought_flat, (h, c))
        interim_thought = rearrange(interim_thought_h[0], '(b h w) c -> b c h w', b=batch, h=rows, w=cols)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought_new = torch.cat([interim_thought, x], 1)
            else:
                interim_thought_new = interim_thought
            interim_thought_new = self.recur_block(interim_thought_new)

            interim_thought_new = rearrange(interim_thought_new, 'b c h w -> (b h w) c')[None,...]
            _, (interim_thought_h, c) = self.lstm(interim_thought_new, (interim_thought_h, c))
            interim_thought = rearrange(interim_thought_h[0], '(b h w) c -> b c h w', b=batch, h=rows, w=cols)

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs

def _build_dt_net_lstm2(width, in_channels, num_blocks, **options):
    """Build a recall ``DTNetLSTM2`` on ``BasicBlock``; ``options`` go to the constructor.

    NOTE(review): ``DTNetLSTM2._make_layer`` builds ``[stride] + [1] * (n - 1)``
    blocks, which is one block when ``n == 0``; the ``_0b`` variants therefore
    still contain one block.
    """
    return DTNetLSTM2(BasicBlock, [num_blocks], width=width, in_channels=in_channels,
                      recall=True, **options)


def dt_net_lstm2_ln_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2)


def dt_net_lstm2_ln_0b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 0)


def dt_net_lstm2_ln_zeros_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, init_h_zeros=True)


def dt_net_lstm2_ln_0b_zeros_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 0, init_h_zeros=True)


# Default ('gal') dropout, two blocks.
def dt_net_lstm2_ln_drop01_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.1)


def dt_net_lstm2_ln_drop02_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.2)


def dt_net_lstm2_ln_drop03_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.3)


def dt_net_lstm2_ln_drop05_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.5)


# Element-wise ('pytorch') dropout on the hidden output.
def dt_net_lstm2_ln_pydrop01_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.1, dropout_method='pytorch')


def dt_net_lstm2_ln_pydrop02_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.2, dropout_method='pytorch')


def dt_net_lstm2_ln_pydrop03_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.3, dropout_method='pytorch')


def dt_net_lstm2_ln_pydrop05_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.5, dropout_method='pytorch')


# 'moon' dropout (fixed mask on the cell state).
def dt_net_lstm2_ln_moondrop01_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.1, dropout_method='moon')


def dt_net_lstm2_ln_moondrop02_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.2, dropout_method='moon')


def dt_net_lstm2_ln_moondrop03_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.3, dropout_method='moon')


def dt_net_lstm2_ln_moondrop05_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.5, dropout_method='moon')


# 'semeniuta' dropout (on the candidate cell update).
def dt_net_lstm2_ln_semdrop01_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.1, dropout_method='semeniuta')


def dt_net_lstm2_ln_semdrop02_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.2, dropout_method='semeniuta')


def dt_net_lstm2_ln_semdrop03_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.3, dropout_method='semeniuta')


def dt_net_lstm2_ln_semdrop04_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.4, dropout_method='semeniuta')


def dt_net_lstm2_ln_semdrop05_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.5, dropout_method='semeniuta')


def dt_net_lstm2_ln_semdrop07_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.7, dropout_method='semeniuta')


def dt_net_lstm2_ln_drop01_0b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 0, _dropout=0.1)


def dt_net_lstm2_ln_drop02_0b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 0, _dropout=0.2)


# Without layer norm on the pre-activations.
def dt_net_lstm2_lnnopre_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, ln_preact=False)


def dt_net_lstm2_lnnopre_0b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 0, ln_preact=False)


def dt_net_lstm2_lnnopre_drop01_2b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 2, _dropout=0.1, ln_preact=False)


def dt_net_lstm2_lnnopre_drop01_0b_2d(width, **kwargs):
    return _build_dt_net_lstm2(width, kwargs["in_channels"], 0, _dropout=0.1, ln_preact=False)

## 21 - 40it OOM
## 17 - 80it OOM
# 15 -30
# 9 -40 is useless...

# https://github.com/ndrplz/ConvLSTM_pytorch/blob/master/convlstm.py
class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell (adapted from ndrplz/ConvLSTM_pytorch).

    One convolution over ``[input, h]`` (concatenated on channels) produces
    the four gate pre-activations, split in the order input, forget, output,
    candidate.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        # "Same" padding for odd kernel sizes.
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

        self.reset_parameters()

    def forward(self, input_tensor, cur_state=None):
        """One ConvLSTM step.

        Args:
            input_tensor: ``(b, input_dim, h, w)``.
            cur_state: ``(h, c)``, each ``(b, hidden_dim, h, w)``; zeros when omitted.

        Returns:
            ``(h_next, c_next)``, each ``(b, hidden_dim, h, w)`` (for odd kernel sizes).
        """
        if cur_state is None:
            cur_state = self.init_hidden(input_tensor.shape[0],(input_tensor.shape[2],input_tensor.shape[3]))

        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)  # a unit forget-gate bias init could be added here
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Zero ``(h, c)`` of shape ``(batch_size, hidden_dim, *image_size)`` on the conv's device."""
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))

    def reset_parameters(self):
        """Xavier-uniform weights (tanh gain) and a zero bias when the conv has one.

        The bias guard fixes an AttributeError for ``bias=False``, where
        ``self.conv.bias`` is ``None``.
        """
        nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain('tanh'))
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)



class ConvLSTMCellV2(nn.Module):
    """Convolutional LSTM cell with split convolutions, normalisation and dropout.

    Compared with ``ConvLSTMCell``, the input and hidden state are convolved
    separately (``conv_i2h`` / ``conv_h2h``). The two results can be
    normalised individually before being summed (``ln_preact``). The new cell
    state is always normalised (``ln_cell``), and several dropout variants are
    selectable via ``dropout_method``.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias,
                dropout: float = 0.0,
                dropout_method: str = "pytorch",
                ln_preact: bool = True,
                learnable: bool = True,
                use_instance_norm=False):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        dropout: float
            Dropout probability; used only in training mode and when > 0.
        dropout_method: str
            "pytorch" (element-wise on the hidden output), "gal" (fixed
            channel mask on the hidden output), "moon" (fixed channel mask on
            the cell state), "semeniuta" (on the candidate ``g``) or "input"
            (on the input-gate pre-activation). "gal"/"moon" require
            ``sample_mask`` to have been called.
        ln_preact: bool
            Normalise the i2h and h2h pre-activations separately.
        learnable: bool
            Whether the normalisation layers have affine parameters.
        use_instance_norm: bool
            Use ``InstanceNorm2d`` instead of ``GroupNorm(1, C)``.
        """

        super(ConvLSTMCellV2, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.hidden_size = hidden_dim  # alias read by sample_mask

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # NOTE: init_hidden takes its device from conv_i2h's weight.
        self.conv_i2h = nn.Conv2d(in_channels=self.input_dim,
                                  out_channels=4 * self.hidden_dim,
                                  kernel_size=self.kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

        self.conv_h2h = nn.Conv2d(in_channels=self.hidden_dim,
                                  out_channels=4 * self.hidden_dim,
                                  kernel_size=self.kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

        if use_instance_norm:
            if ln_preact:
                self.ln_i2h = nn.InstanceNorm2d(4 * self.hidden_dim, affine=learnable)
                self.ln_h2h = nn.InstanceNorm2d(4 * self.hidden_dim, affine=learnable)
            self.ln_cell = nn.InstanceNorm2d(self.hidden_dim, affine=learnable)
        else:
            # GroupNorm with a single group normalises over channels and space
            # jointly, i.e. layer norm for feature maps.
            if ln_preact:
                self.ln_i2h = nn.GroupNorm(1, 4 * self.hidden_dim, affine=learnable)
                self.ln_h2h = nn.GroupNorm(1, 4 * self.hidden_dim, affine=learnable)
            self.ln_cell = nn.GroupNorm(1, self.hidden_dim, affine=learnable)
        self.ln_preact = ln_preact

        self.dropout = dropout
        self.dropout_method = dropout_method

        self.reset_parameters()

    def forward(self, input_tensor, cur_state=None):
        """One ConvLSTM step.

        Args:
            input_tensor: ``(b, input_dim, h, w)``.
            cur_state: ``(h, c)``, each ``(b, hidden_dim, h, w)``; zeros when omitted.

        Returns:
            ``(h_next, c_next)``; ``c_next`` is the normalised cell state.
        """
        do_dropout = self.training and self.dropout > 0.0

        if cur_state is None:
            cur_state = self.init_hidden(input_tensor.shape[0], (input_tensor.shape[2], input_tensor.shape[3]))

        h_cur, c_cur = cur_state

        # Separate input and hidden mappings, each optionally normalised.
        i2h = self.conv_i2h(input_tensor)
        h2h = self.conv_h2h(h_cur)
        if self.ln_preact:
            i2h = self.ln_i2h(i2h)
            h2h = self.ln_h2h(h2h)
        combined_conv = i2h + h2h

        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        if do_dropout and self.dropout_method == "input":
            cc_i = F.dropout(cc_i, p=self.dropout, training=self.training)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)  # a unit forget-gate bias init could be added here
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        # cell computations
        if do_dropout and self.dropout_method == "semeniuta":
            g = F.dropout(g, p=self.dropout, training=self.training)

        c_next = f * c_cur + i * g

        if do_dropout and self.dropout_method == "moon":
            # self.mask is (1, hidden_dim, 1, 1): one keep/drop per channel.
            # Differentiable masking replaces ``c_next.data.set_(...)``, which
            # current PyTorch rejects and which hid the mask from autograd.
            c_next = th.mul(c_next, self.mask) * (1.0 / (1.0 - self.dropout))

        c_next = self.ln_cell(c_next)

        h_next = o * torch.tanh(c_next)

        if do_dropout:
            if self.dropout_method == "pytorch":
                h_next = F.dropout(h_next, p=self.dropout, training=self.training)
            if self.dropout_method == "gal":
                h_next = th.mul(h_next, self.mask) * (1.0 / (1.0 - self.dropout))

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Zero ``(h, c)`` of shape ``(batch_size, hidden_dim, *image_size)`` on conv_i2h's device."""
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv_i2h.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv_i2h.weight.device))

    def reset_parameters(self):
        """Re-draw every parameter uniformly from ``[-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)]``.

        NOTE(review): this runs after the normalisation layers are created, so
        their affine gains/biases are overwritten too (gains do not start at
        1). Confirm this is intended.
        """
        std = 1.0 / math.sqrt(self.hidden_dim)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def sample_mask(self, device='cpu'):
        """Draw a per-channel Bernoulli keep-mask ``(1, hidden_size, 1, 1)`` for "gal"/"moon"."""
        keep = 1.0 - self.dropout
        self.mask = th.bernoulli(th.empty(1, self.hidden_size, 1, 1).fill_(keep)).to(device)



class ConvLSTMCellV2SemDrop(nn.Module):
    """ConvLSTM cell with separate i2h/h2h convolutions and forced candidate dropout.

    The input-to-hidden and hidden-to-hidden pre-activations are computed by
    separate convolutions. Each is optionally normalised, and the cell state is
    normalised before the output.

    Whenever dropout is active (training mode and ``dropout > 0``), dropout is
    applied to the candidate ``g`` regardless of ``dropout_method``
    (Semeniuta-style). ``dropout_method`` selects an additional variant on top.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias,
                 dropout: float = 0.0,
                 dropout_method: str = "pytorch",
                 ln_preact: bool = True,
                 learnable: bool = True,
                 use_instance_norm=False):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        dropout: float
            Dropout probability; only used in training mode when > 0.
        dropout_method: str
            Additional dropout variant:
            - "input": dropout on the input-gate pre-activation.
            - "moon": fixed mask on the cell state.
            - "pytorch": standard dropout on the hidden state.
            - "gal": fixed mask on the hidden state.
            "moon" and "gal" require ``sample_mask`` to have been called.
        ln_preact: bool
            Normalise the i2h and h2h pre-activations separately.
        learnable: bool
            Give the normalisation layers affine parameters.
        use_instance_norm: bool
            Use InstanceNorm2d instead of GroupNorm(1, C), i.e. layer norm.
        """
        super(ConvLSTMCellV2SemDrop, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.hidden_size = hidden_dim  # alias read by sample_mask

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # NOTE: init_hidden takes the device/dtype of the states from conv_i2h.weight.
        self.conv_i2h = nn.Conv2d(in_channels=self.input_dim,
                                  out_channels=4 * self.hidden_dim,
                                  kernel_size=self.kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

        self.conv_h2h = nn.Conv2d(in_channels=self.hidden_dim,
                                  out_channels=4 * self.hidden_dim,
                                  kernel_size=self.kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

        # GroupNorm with a single group (1, channels) is equivalent to layer norm.
        if use_instance_norm:
            if ln_preact:
                self.ln_i2h = nn.InstanceNorm2d(4 * self.hidden_dim, affine=learnable)
                self.ln_h2h = nn.InstanceNorm2d(4 * self.hidden_dim, affine=learnable)
            self.ln_cell = nn.InstanceNorm2d(self.hidden_dim, affine=learnable)
        else:
            if ln_preact:
                self.ln_i2h = nn.GroupNorm(1, 4 * self.hidden_dim, affine=learnable)
                self.ln_h2h = nn.GroupNorm(1, 4 * self.hidden_dim, affine=learnable)
            self.ln_cell = nn.GroupNorm(1, self.hidden_dim, affine=learnable)
        self.ln_preact = ln_preact

        self.dropout = dropout
        self.dropout_method = dropout_method

        self.reset_parameters()

    def forward(self, input_tensor, cur_state=None):
        """Advance the cell by one step and return ``(h_next, c_next)``.

        ``cur_state`` is an ``(h, c)`` pair, or ``None`` to start from zeros.
        """
        do_dropout = self.training and self.dropout > 0.0

        if cur_state is None:
            cur_state = self.init_hidden(input_tensor.shape[0],
                                         (input_tensor.shape[2], input_tensor.shape[3]))
        h_cur, c_cur = cur_state

        # Separate linear maps for input and recurrent state, each optionally normalised.
        i2h = self.conv_i2h(input_tensor)
        h2h = self.conv_h2h(h_cur)
        if self.ln_preact:
            i2h = self.ln_i2h(i2h)
            h2h = self.ln_h2h(h2h)
        combined_conv = i2h + h2h

        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        if do_dropout and self.dropout_method == "input":
            cc_i = F.dropout(cc_i, p=self.dropout, training=self.training)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)  # a unit forget-gate bias init could be added here
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        # Candidate dropout is unconditional in this variant (not gated on dropout_method).
        if do_dropout:
            g = F.dropout(g, p=self.dropout, training=self.training)

        c_next = f * c_cur + i * g

        if do_dropout and self.dropout_method == "moon":
            # Fixed per-channel mask from sample_mask with inverted-dropout scaling.
            # Out-of-place so autograd tracks it; writing through `.data.set_`
            # is rejected by current PyTorch.
            c_next = c_next * self.mask * (1.0 / (1.0 - self.dropout))

        c_next = self.ln_cell(c_next)

        h_next = o * torch.tanh(c_next)

        if do_dropout:
            if self.dropout_method == "pytorch":
                h_next = F.dropout(h_next, p=self.dropout, training=self.training)
            if self.dropout_method == "gal":
                h_next = h_next * self.mask * (1.0 / (1.0 - self.dropout))

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states of shape (batch, hidden_dim, height, width).

        Both tensors use the device and dtype of ``conv_i2h.weight``, so
        ``.double()`` / ``.half()`` models receive matching states.
        """
        height, width = image_size
        weight = self.conv_i2h.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

    def reset_parameters(self):
        """Initialise every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))."""
        std = 1.0 / math.sqrt(self.hidden_dim)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def sample_mask(self, device='cpu'):
        """Draw the fixed per-channel keep-mask, shape (1, hidden_size, 1, 1), used by "moon"/"gal"."""
        keep = 1.0 - self.dropout
        self.mask = th.bernoulli(th.empty(1, self.hidden_size, 1, 1).fill_(keep)).to(device)




class ConvLSTMCellV2_ReduceInput(nn.Module):
    """ConvLSTM cell with separate i2h/h2h convolutions that also returns ``cc_i``.

    ``forward`` returns ``(h_next, c_next, cc_i)``, where ``cc_i`` is the
    input-gate pre-activation before the sigmoid. The "input" dropout method
    is not implemented in this variant.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias,
                 dropout: float = 0.0,
                 dropout_method: str = "pytorch",
                 ln_preact: bool = True,
                 learnable: bool = True,
                 use_instance_norm=False):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        dropout: float
            Dropout probability; only used in training mode when > 0.
        dropout_method: str
            Dropout variant:
            - "semeniuta": dropout on the candidate.
            - "moon": fixed mask on the cell state.
            - "pytorch": standard dropout on the hidden state.
            - "gal": fixed mask on the hidden state.
            "moon" and "gal" require ``sample_mask`` to have been called.
        ln_preact: bool
            Normalise the i2h and h2h pre-activations separately.
        learnable: bool
            Give the normalisation layers affine parameters.
        use_instance_norm: bool
            Use InstanceNorm2d instead of GroupNorm(1, C), i.e. layer norm.
        """
        super(ConvLSTMCellV2_ReduceInput, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.hidden_size = hidden_dim  # alias read by sample_mask

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # NOTE: init_hidden takes the device/dtype of the states from conv_i2h.weight.
        self.conv_i2h = nn.Conv2d(in_channels=self.input_dim,
                                  out_channels=4 * self.hidden_dim,
                                  kernel_size=self.kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

        self.conv_h2h = nn.Conv2d(in_channels=self.hidden_dim,
                                  out_channels=4 * self.hidden_dim,
                                  kernel_size=self.kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

        # GroupNorm with a single group (1, channels) is equivalent to layer norm.
        if use_instance_norm:
            if ln_preact:
                self.ln_i2h = nn.InstanceNorm2d(4 * self.hidden_dim, affine=learnable)
                self.ln_h2h = nn.InstanceNorm2d(4 * self.hidden_dim, affine=learnable)
            self.ln_cell = nn.InstanceNorm2d(self.hidden_dim, affine=learnable)
        else:
            if ln_preact:
                self.ln_i2h = nn.GroupNorm(1, 4 * self.hidden_dim, affine=learnable)
                self.ln_h2h = nn.GroupNorm(1, 4 * self.hidden_dim, affine=learnable)
            self.ln_cell = nn.GroupNorm(1, self.hidden_dim, affine=learnable)
        self.ln_preact = ln_preact

        self.dropout = dropout
        self.dropout_method = dropout_method

        self.reset_parameters()

    def forward(self, input_tensor, cur_state=None):
        """Advance the cell by one step and return ``(h_next, c_next, cc_i)``.

        ``cur_state`` is an ``(h, c)`` pair, or ``None`` to start from zeros.
        """
        do_dropout = self.training and self.dropout > 0.0

        if cur_state is None:
            cur_state = self.init_hidden(input_tensor.shape[0],
                                         (input_tensor.shape[2], input_tensor.shape[3]))
        h_cur, c_cur = cur_state

        # Separate linear maps for input and recurrent state, each optionally normalised.
        i2h = self.conv_i2h(input_tensor)
        h2h = self.conv_h2h(h_cur)
        if self.ln_preact:
            i2h = self.ln_i2h(i2h)
            h2h = self.ln_h2h(h2h)
        combined_conv = i2h + h2h

        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)  # a unit forget-gate bias init could be added here
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        if do_dropout and self.dropout_method == "semeniuta":
            g = F.dropout(g, p=self.dropout, training=self.training)

        c_next = f * c_cur + i * g

        if do_dropout and self.dropout_method == "moon":
            # Fixed per-channel mask from sample_mask with inverted-dropout scaling.
            # Out-of-place so autograd tracks it; writing through `.data.set_`
            # is rejected by current PyTorch.
            c_next = c_next * self.mask * (1.0 / (1.0 - self.dropout))

        c_next = self.ln_cell(c_next)

        h_next = o * torch.tanh(c_next)

        if do_dropout:
            if self.dropout_method == "pytorch":
                h_next = F.dropout(h_next, p=self.dropout, training=self.training)
            if self.dropout_method == "gal":
                h_next = h_next * self.mask * (1.0 / (1.0 - self.dropout))

        return h_next, c_next, cc_i

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states of shape (batch, hidden_dim, height, width).

        Both tensors use the device and dtype of ``conv_i2h.weight``, so
        ``.double()`` / ``.half()`` models receive matching states.
        """
        height, width = image_size
        weight = self.conv_i2h.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

    def reset_parameters(self):
        """Initialise every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))."""
        std = 1.0 / math.sqrt(self.hidden_dim)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def sample_mask(self, device='cpu'):
        """Draw the fixed per-channel keep-mask, shape (1, hidden_size, 1, 1), used by "moon"/"gal"."""
        keep = 1.0 - self.dropout
        self.mask = th.bernoulli(th.empty(1, self.hidden_size, 1, 1).fill_(keep)).to(device)



## FIXME: the idea is for this version to be optimized...
class ConvLSTMCellV3(nn.Module):
    """ConvLSTM cell whose input projection is computed separately.

    ``forward_input`` applies the input-to-hidden convolution (and, if
    ``ln_preact``, its normalisation). ``forward`` then consumes that ``i2h``
    tensor together with the recurrent state.

    Supports 1-D (``conv_dim=1``) and 2-D (``conv_dim=2``) convolutions.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias,
                 dropout: float = 0.0,
                 dropout_method: str = "pytorch",
                 ln_preact: bool = True,
                 learnable: bool = True,
                 use_instance_norm=False,
                 conv_dim: int = 2):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel (only the first entry is used
            when ``conv_dim == 1``).
        bias: bool
            Whether or not to add the bias.
        dropout: float
            Dropout probability; only used in training mode when > 0.
        dropout_method: str
            Dropout variant:
            - "input": dropout on the input-gate pre-activation.
            - "semeniuta": dropout on the candidate.
            - "moon": fixed mask on the cell state.
            - "pytorch": standard dropout on the hidden state.
            - "gal": fixed mask on the hidden state.
            "moon" and "gal" require ``sample_mask`` to have been called.
        ln_preact: bool
            Normalise the i2h and h2h pre-activations separately.
        learnable: bool
            Give the normalisation layers affine parameters.
        use_instance_norm: bool
            Use InstanceNorm instead of GroupNorm(1, C), i.e. layer norm.
        conv_dim: int
            1 for Conv1d, 2 for Conv2d; anything else raises ValueError.
        """
        super(ConvLSTMCellV3, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.hidden_size = hidden_dim  # alias read by sample_mask

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv_dim = conv_dim

        if conv_dim == 2:
            conv_class = nn.Conv2d
            instance_norm_class = nn.InstanceNorm2d
        elif conv_dim == 1:
            conv_class = nn.Conv1d
            instance_norm_class = nn.InstanceNorm1d
            # 1-D convolutions use only the first kernel/padding entry.
            self.kernel_size = self.kernel_size[0]
            self.padding = self.padding[0]
        else:
            raise ValueError("conv_dim must be 1 or 2")

        # NOTE: init_hidden takes the device/dtype of the states from conv_i2h.weight.
        self.conv_i2h = conv_class(in_channels=self.input_dim,
                                   out_channels=4 * self.hidden_dim,
                                   kernel_size=self.kernel_size,
                                   padding=self.padding,
                                   bias=self.bias)

        self.conv_h2h = conv_class(in_channels=self.hidden_dim,
                                   out_channels=4 * self.hidden_dim,
                                   kernel_size=self.kernel_size,
                                   padding=self.padding,
                                   bias=self.bias)

        # GroupNorm(1, channels) is equivalent to layer norm and works for 1-D and 2-D.
        if use_instance_norm:
            if ln_preact:
                self.ln_i2h = instance_norm_class(4 * self.hidden_dim, affine=learnable)
                self.ln_h2h = instance_norm_class(4 * self.hidden_dim, affine=learnable)
            self.ln_cell = instance_norm_class(self.hidden_dim, affine=learnable)
        else:
            if ln_preact:
                self.ln_i2h = nn.GroupNorm(1, 4 * self.hidden_dim, affine=learnable)
                self.ln_h2h = nn.GroupNorm(1, 4 * self.hidden_dim, affine=learnable)
            self.ln_cell = nn.GroupNorm(1, self.hidden_dim, affine=learnable)
        self.ln_preact = ln_preact

        self.dropout = dropout
        self.dropout_method = dropout_method

        self.reset_parameters()

    def forward_input(self, input_tensor):
        """Return the (optionally normalised) input-to-hidden pre-activations."""
        i2h = self.conv_i2h(input_tensor)
        if self.ln_preact:
            i2h = self.ln_i2h(i2h)
        return i2h

    def forward(self, i2h, cur_state=None):
        """Advance the cell by one step and return ``(h_next, c_next)``.

        ``i2h`` is the output of ``forward_input``. ``cur_state`` is an
        ``(h, c)`` pair, or ``None`` to start from zeros.
        """
        do_dropout = self.training and self.dropout > 0.0

        if cur_state is None:
            if self.conv_dim == 2:
                cur_state = self.init_hidden(i2h.shape[0], (i2h.shape[2], i2h.shape[3]))
            elif self.conv_dim == 1:
                cur_state = self.init_hidden(i2h.shape[0], (i2h.shape[2],))
        h_cur, c_cur = cur_state

        # The input map was already computed (and normalised) by forward_input.
        h2h = self.conv_h2h(h_cur)
        if self.ln_preact:
            h2h = self.ln_h2h(h2h)
        combined_conv = i2h + h2h

        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        if do_dropout and self.dropout_method == "input":
            cc_i = F.dropout(cc_i, p=self.dropout, training=self.training)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)  # a unit forget-gate bias init could be added here
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        if do_dropout and self.dropout_method == "semeniuta":
            g = F.dropout(g, p=self.dropout, training=self.training)

        c_next = f * c_cur + i * g

        if do_dropout and self.dropout_method == "moon":
            # Fixed per-channel mask from sample_mask with inverted-dropout scaling.
            # Out-of-place so autograd tracks it; writing through `.data.set_`
            # is rejected by current PyTorch.
            c_next = c_next * self.mask * (1.0 / (1.0 - self.dropout))

        c_next = self.ln_cell(c_next)

        h_next = o * torch.tanh(c_next)

        if do_dropout:
            if self.dropout_method == "pytorch":
                h_next = F.dropout(h_next, p=self.dropout, training=self.training)
            if self.dropout_method == "gal":
                h_next = h_next * self.mask * (1.0 / (1.0 - self.dropout))

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states on the device and dtype of ``conv_i2h.weight``.

        ``image_size`` is ``(height, width)`` when ``conv_dim == 2`` and
        ``(length,)`` when ``conv_dim == 1``.
        """
        weight = self.conv_i2h.weight
        if self.conv_dim == 2:
            height, width = image_size
            shape = (batch_size, self.hidden_dim, height, width)
        elif self.conv_dim == 1:
            shape = (batch_size, self.hidden_dim, image_size[0])
        else:
            raise ValueError("conv_dim must be 1 or 2")
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

    def reset_parameters(self):
        """Initialise every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))."""
        std = 1.0 / math.sqrt(self.hidden_dim)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def sample_mask(self, device='cpu'):
        """Draw the fixed per-channel keep-mask used by "moon"/"gal".

        The shape is (1, hidden_size, 1, 1) for 2-D cells and (1, hidden_size, 1)
        for 1-D cells, so it broadcasts against the state of either rank.
        """
        keep = 1.0 - self.dropout
        shape = (1, self.hidden_size) + (1,) * self.conv_dim
        self.mask = th.bernoulli(th.empty(shape).fill_(keep)).to(device)




class ConvLSTMCellV3_NoLayerNormalization(nn.Module):
    """ConvLSTMCellV3 without any normalisation layers.

    ``forward_input`` applies only the input-to-hidden convolution. ``forward``
    consumes that ``i2h`` tensor together with the recurrent state.
    ``ln_preact``, ``learnable`` and ``use_instance_norm`` are accepted for
    signature compatibility with ConvLSTMCellV3 but are ignored.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias,
                 dropout: float = 0.0,
                 dropout_method: str = "pytorch",
                 ln_preact: bool = True,
                 learnable: bool = True,
                 use_instance_norm=False,
                 conv_dim: int = 2):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel (only the first entry is used
            when ``conv_dim == 1``).
        bias: bool
            Whether or not to add the bias.
        dropout: float
            Dropout probability; only used in training mode when > 0.
        dropout_method: str
            Dropout variant:
            - "input": dropout on the input-gate pre-activation.
            - "semeniuta": dropout on the candidate.
            - "moon": fixed mask on the cell state.
            - "pytorch": standard dropout on the hidden state.
            - "gal": fixed mask on the hidden state.
            "moon" and "gal" require ``sample_mask`` to have been called.
        ln_preact, learnable, use_instance_norm:
            Ignored (kept for signature compatibility).
        conv_dim: int
            1 for Conv1d, 2 for Conv2d; anything else raises ValueError.
        """
        super(ConvLSTMCellV3_NoLayerNormalization, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.hidden_size = hidden_dim  # alias read by sample_mask

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv_dim = conv_dim

        if conv_dim == 2:
            conv_class = nn.Conv2d
        elif conv_dim == 1:
            conv_class = nn.Conv1d
            # 1-D convolutions use only the first kernel/padding entry.
            self.kernel_size = self.kernel_size[0]
            self.padding = self.padding[0]
        else:
            raise ValueError("conv_dim must be 1 or 2")

        # NOTE: init_hidden takes the device/dtype of the states from conv_i2h.weight.
        self.conv_i2h = conv_class(in_channels=self.input_dim,
                                   out_channels=4 * self.hidden_dim,
                                   kernel_size=self.kernel_size,
                                   padding=self.padding,
                                   bias=self.bias)

        self.conv_h2h = conv_class(in_channels=self.hidden_dim,
                                   out_channels=4 * self.hidden_dim,
                                   kernel_size=self.kernel_size,
                                   padding=self.padding,
                                   bias=self.bias)

        self.dropout = dropout
        self.dropout_method = dropout_method

        self.reset_parameters()

    def forward_input(self, input_tensor):
        """Return the raw input-to-hidden pre-activations (no normalisation)."""
        return self.conv_i2h(input_tensor)

    def forward(self, i2h, cur_state=None):
        """Advance the cell by one step and return ``(h_next, c_next)``.

        ``i2h`` is the output of ``forward_input``. ``cur_state`` is an
        ``(h, c)`` pair, or ``None`` to start from zeros.
        """
        do_dropout = self.training and self.dropout > 0.0

        if cur_state is None:
            if self.conv_dim == 2:
                cur_state = self.init_hidden(i2h.shape[0], (i2h.shape[2], i2h.shape[3]))
            elif self.conv_dim == 1:
                cur_state = self.init_hidden(i2h.shape[0], (i2h.shape[2],))
        h_cur, c_cur = cur_state

        h2h = self.conv_h2h(h_cur)
        combined_conv = i2h + h2h

        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        if do_dropout and self.dropout_method == "input":
            cc_i = F.dropout(cc_i, p=self.dropout, training=self.training)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)  # a unit forget-gate bias init could be added here
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        if do_dropout and self.dropout_method == "semeniuta":
            g = F.dropout(g, p=self.dropout, training=self.training)

        c_next = f * c_cur + i * g

        if do_dropout and self.dropout_method == "moon":
            # Fixed per-channel mask from sample_mask with inverted-dropout scaling.
            # Out-of-place so autograd tracks it; writing through `.data.set_`
            # is rejected by current PyTorch.
            c_next = c_next * self.mask * (1.0 / (1.0 - self.dropout))

        h_next = o * torch.tanh(c_next)

        if do_dropout:
            if self.dropout_method == "pytorch":
                h_next = F.dropout(h_next, p=self.dropout, training=self.training)
            if self.dropout_method == "gal":
                h_next = h_next * self.mask * (1.0 / (1.0 - self.dropout))

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states on the device and dtype of ``conv_i2h.weight``.

        ``image_size`` is ``(height, width)`` when ``conv_dim == 2`` and
        ``(length,)`` when ``conv_dim == 1``.
        """
        weight = self.conv_i2h.weight
        if self.conv_dim == 2:
            height, width = image_size
            shape = (batch_size, self.hidden_dim, height, width)
        elif self.conv_dim == 1:
            shape = (batch_size, self.hidden_dim, image_size[0])
        else:
            raise ValueError("conv_dim must be 1 or 2")
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

    def reset_parameters(self):
        """Initialise every parameter from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))."""
        std = 1.0 / math.sqrt(self.hidden_dim)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def sample_mask(self, device='cpu'):
        """Draw the fixed per-channel keep-mask used by "moon"/"gal".

        The shape is (1, hidden_size, 1, 1) for 2-D cells and (1, hidden_size, 1)
        for 1-D cells, so it broadcasts against the state of either rank.
        """
        keep = 1.0 - self.dropout
        shape = (1, self.hidden_size) + (1,) * self.conv_dim
        self.mask = th.bernoulli(th.empty(shape).fill_(keep)).to(device)



class ConvGRUCell(nn.Module):
    """Convolutional GRU cell whose gates are computed with 2-D convolutions."""

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize the ConvGRU cell.

        :param input_dim: int
            Number of channels of input tensor.
        :param hidden_dim: int
            Number of channels of hidden state.
        :param kernel_size: (int, int)
            Size of the convolutional kernel.
        :param bias: bool
            Whether or not to add the bias.
        """
        super(ConvGRUCell, self).__init__()
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.hidden_dim = hidden_dim
        self.bias = bias

        # Produces the reset-gate and update-gate logits, in that order.
        self.conv_gates = nn.Conv2d(in_channels=input_dim + hidden_dim,
                                    out_channels=2 * self.hidden_dim,
                                    kernel_size=kernel_size,
                                    padding=self.padding,
                                    bias=self.bias)

        # Produces the candidate hidden state.
        self.conv_can = nn.Conv2d(in_channels=input_dim + hidden_dim,
                                  out_channels=self.hidden_dim,
                                  kernel_size=kernel_size,
                                  padding=self.padding,
                                  bias=self.bias)

    def init_hidden(self, batch_size, image_size, device, dtype=None):
        """Return a zero hidden state of shape (batch_size, hidden_dim, height, width).

        The tensor is created directly on ``device``. ``dtype=None`` keeps
        torch's default floating dtype.
        """
        height, width = image_size
        return torch.zeros(batch_size, self.hidden_dim, height, width,
                           device=device, dtype=dtype)

    def forward(self, input_tensor, h_cur=None):
        """
        Advance the cell by one step.

        :param input_tensor: (b, c, h, w)
            Input at the current step.
        :param h_cur: (b, c_hidden, h, w) or None
            Current hidden state; ``None`` starts from zeros that match the
            conv weights' dtype.
        :return: (h_next, beta)
            Next hidden state, and the raw (pre-sigmoid) update-gate logits.
        """
        if h_cur is None:
            h_cur = self.init_hidden(input_tensor.size(0),
                                     (input_tensor.size(2), input_tensor.size(3)),
                                     input_tensor.device,
                                     dtype=self.conv_gates.weight.dtype)

        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv_gates(combined)

        gamma, beta = torch.split(combined_conv, self.hidden_dim, dim=1)
        reset_gate = torch.sigmoid(gamma)
        update_gate = torch.sigmoid(beta)

        combined = torch.cat([input_tensor, reset_gate * h_cur], dim=1)
        cnm = torch.tanh(self.conv_can(combined))

        h_next = (1 - update_gate) * h_cur + update_gate * cnm
        return h_next, beta




# ### https://github.com/KL4805/ConvLSTM-Pytorch
# ## according to the forecast paper (based on ndrplz)
# ## peephole implementation
# class ConvLSTMCell(nn.Module):

#     def __init__(self, input_dim, hidden_dim, kernel_size, bias):
#         """
#         Initialize ConvLSTM cell.
#         Parameters
#         ----------
#         input_dim: int
#             Number of channels of input tensor.
#         hidden_dim: int
#             Number of channels of hidden state.
#         kernel_size: (int, int)
#             Size of the convolutional kernel.
#         bias: bool
#             Whether or not to add the bias.
#         """

#         super(ConvLSTMCell, self).__init__()

#         self.input_dim = input_dim
#         self.hidden_dim = hidden_dim

#         self.kernel_size = kernel_size
#         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
#         self.bias = bias

#         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
#                               out_channels=4 * self.hidden_dim,
#                               kernel_size=self.kernel_size,
#                               padding=self.padding,
#                               bias=self.bias)
#         self.conv_c = nn.Conv2d(in_channels = self.hidden_dim, 
#                                out_channels = 2 * self.hidden_dim, 
#                                kernel_size = self.kernel_size, 
#                                padding = self.padding, 
#                                bias = self.bias)
#         self.conv_cnext = nn.Conv2d(in_channels = self.hidden_dim,
#                                     out_channels = self.hidden_dim, 
#                                     kernel_size = self.kernel_size, 
#                                     padding = self.padding, 
#                                     bias = self.bias)

#     def forward(self, input_tensor, cur_state):
#         # cur_state is passed in as a parameter, not a member variable. 
#         # this now fully simulates the dynamics mentioned in the paper. 
#         h_cur, c_cur = cur_state

#         combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis

#         combined_conv = self.conv(combined)
#         cprev_conv = self.conv_c(c_cur)
#         cprev_i, cprev_f = torch.split(cprev_conv, self.hidden_dim, dim = 1)
#         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
#         i = torch.sigmoid(cc_i + cprev_i)
#         f = torch.sigmoid(cc_f + cprev_f)
#         g = torch.tanh(cc_g)
#         c_next = f * c_cur + i * g
#         cnext_conv = self.conv_cnext(c_next) # may be problematic, because this step runs one extra conv

#         o = torch.sigmoid(cc_o + cnext_conv)
#         h_next = o * torch.tanh(c_next)

#         return h_next, c_next

#     def init_hidden(self, batch_size, image_size):
#         height, width = image_size
#         return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
#                 torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))


class SampleDropConv(nn.Module):
    """Applies dropout to 4-D input samples with a mask that stays fixed between calls.

    ``set_weights(X)`` draws a new mask shaped like ``X``. While training,
    every subsequent ``forward`` multiplies its input by that same mask.
    """

    def __init__(self, dropout=0):
        super().__init__()

        if not 0 <= dropout < 1:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")
        self._mask = None
        self._dropout = dropout

    def set_weights(self, X):
        """Calculate a new dropout mask with the shape, device and dtype of ``X``.

        ``X`` must be 4-D. In eval mode the mask is all ones.
        """
        if X.dim() != 4:
            raise ValueError(f"expected a 4-D tensor, got shape {tuple(X.shape)}")
        # new_ones: same device (any backend) and dtype as X, no autograd history.
        mask = X.new_ones(X.shape)
        self._mask = F.dropout(mask, p=self._dropout, training=self.training)

    def forward(self, X):
        """Apply the stored mask to ``X``; identity in eval mode or when dropout is 0."""
        if not self.training or not self._dropout:
            return X
        if self._mask is None:
            raise RuntimeError("SampleDropConv: call set_weights() before forward() in training mode")
        return X * self._mask
        


# 6 times slower than the conv net,
# but it learns, whereas the conv net does not learn as quickly in the single-batch test.

### In short, I did not notice big differences between the initializations
### (perhaps only the +1 variant is more stable? too few seeds to test the hypothesis)
## 3 LSTM layers learn faster than the conv net, despite running slower and using more parameters (4*width == 4*convs)

## TODO: fix dropout
class NetConvLSTM(nn.Module):
    """Recurrent thinking network made of three stacked ConvLSTMCell layers.

    ``x`` is projected to ``width`` channels (3x3 conv + ReLU). Every requested
    iteration then runs the 3-layer ConvLSTM stack ``mul`` (= 2) times, with
    ``x`` concatenated to the first layer's input at every step (recall).
    After every ``mul`` stack steps a small conv head maps the current thought
    to 2 output channels.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,
                  _dropout_i=0,_dropout_h=0,**kwargs):
        """Build the projection, head, ConvLSTM stack and fixed-mask dropouts.

        Args:
            block, num_blocks: not used by ``__init__`` (the residual recurrent
                block is commented out); only ``_make_layer`` uses ``block``.
                Kept for signature compatibility with the other model classes.
            width: channel count of the recurrent "thought".
            in_channels: channel count of ``x``.
            recall: stored; ``forward`` only supports ``True``.
            group_norm: stored for ``_make_layer``; otherwise unused.
            bias: bias flag of the projection and head convs. The ConvLSTMCell
                layers get a hard-coded ``True`` as their 4th argument
                (presumably their bias flag -- TODO confirm).
            _dropout_i: drop probability of ``self._input_drop``.
            _dropout_h: drop probability of ``self._state_drop``.
            **kwargs: ignored.
        """
        super().__init__()

        self.name = "NetConvLSTM"

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        # conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
        #                         stride=1, padding=1, bias=bias)

        # recur_layers = []
        # if recall:
        #     recur_layers.append(conv_recall)

        # for i in range(len(num_blocks)):
        #     recur_layers.append(self._make_layer(block, width, num_blocks[i], stride=1))

        # (refers to the commented-out residual design above)
        # each recurrent layer runs 1 + 2*2 convolutions = 5 convolutions

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        # self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)


        # Layer 1 reads [thought, x] (width + in_channels channels);
        # layers 2 and 3 read the previous layer's width-channel output.
        # self.lstm = LSTMCellImproved(width, width)
        self.lstm = ConvLSTMCell(width + in_channels, width, (3,3), True)
        self.lstm2 = ConvLSTMCell(width, width, (3,3), True)
        self.lstm3 = ConvLSTMCell(width, width, (3,3), True)


        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        # The same two masks are reused for all three layers and every step
        # of a forward pass (drawn once in forward at i == 0).
        self._input_drop = SampleDropConv(dropout=self._dropout_i)
        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` (the first with ``stride``).

        Side effect: sets ``self.width`` to ``planes * block.expansion``.
        Not called by ``__init__`` or ``forward`` of this class.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` thinking iterations (``2 * iters_to_do`` stack steps).

        Args:
            x: input batch of shape (B, in_channels, H, W).
            iters_to_do: number of head outputs to produce.
            interim_thought: must be None; resuming from a given thought is not
                implemented (AssertionError, unless Python runs with -O).
            return_all_outputs: in training mode, also return ``all_outputs``.
            **kwargs: ignored.

        Returns:
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)``
            when ``return_all_outputs``; ``out`` is the last head output and
            ``thought`` the last (post-dropout) layer-3 output.
            Eval mode: ``all_outputs`` of shape (B, iters_to_do, 2, H, W).
            In training mode with ``iters_to_do == 0`` the head never runs and
            the return raises UnboundLocalError on ``out``.
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought
        else:
            assert False, "not implemented"

        # One slot per requested iteration; filled every `mul` stack steps.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3))).to(x.device)
        
        # interim_thought_h, c = self.lstm(self._input_drop(interim_thought_flat))
        # interim_thought = rearrange(interim_thought_h, '(b h w) c -> b c h w', b=x.size(0), h=x.size(2), w=x.size(3))
        
        # Number of 3-layer stack steps per head output.
        mul=2

        for i in range(iters_to_do*mul):
            if i==0:
                # Draw this pass's fixed dropout masks, shaped like the
                # projected thought.
                self._input_drop.set_weights(interim_thought)
                self._state_drop.set_weights(interim_thought)

                # None lets ConvLSTMCell create its own initial state
                # (presumably zeros -- TODO confirm in ConvLSTMCell).
                state = None
            else:
                # NOTE(review): layer 1's hidden state is the previous step's
                # *layer-3* output (after dropout), paired with layer 1's own c.
                state = (interim_thought,c)
                

            if self.recall:
                interim_thought_new = torch.cat([self._input_drop(interim_thought), x], 1)
            else:
                assert False, "not implemented"

            interim_thought, c = self.lstm(interim_thought_new,state)
            interim_thought = self._state_drop(interim_thought)

            # NOTE(review): layer 2 never carries its own h between steps: its
            # hidden state is layer 1's current output (also its input), and
            # only its cell state c2 is carried over (zeros at i == 0).
            if i==0:
                state2=(interim_thought,torch.zeros_like(c).to(c.device))
            else:
                # should it be state drop from other lstm though?
                state2 = (interim_thought,c2)

            interim_thought, c2 = self.lstm2(self._input_drop(interim_thought),state2)
            interim_thought = self._state_drop(interim_thought)

            # Same wiring for layer 3: h = layer 2's output, c3 carried over.
            if i==0:
                state3=(interim_thought,torch.zeros_like(c).to(c.device))
            else:
                state3 = (interim_thought,c3)
            
            interim_thought, c3 = self.lstm3(self._input_drop(interim_thought),state3)
            interim_thought = self._state_drop(interim_thought)

            # Decode only on the last stack step of each group of `mul`.
            if i%mul==mul-1:
                out = self.head(interim_thought)
                all_outputs[:, i//mul] = out

            # out = self.head(interim_thought)
            # all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs

def dt_convlstm_2d(width, **kwargs):
    """Build a NetConvLSTM (three stacked ConvLSTM layers) without dropout.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    n_in = kwargs["in_channels"]
    return NetConvLSTM(BasicBlock, [2], width=width, in_channels=n_in, recall=True)

def dt_convlstm_drop02_2d(width, **kwargs):
    """Build a NetConvLSTM with fixed-mask dropout p=0.2 on LSTM inputs and outputs.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.2
    return NetConvLSTM(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
    )

def dt_convlstm_drop01_2d(width, **kwargs):
    """Build a NetConvLSTM with fixed-mask dropout p=0.1 on LSTM inputs and outputs.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.1
    return NetConvLSTM(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
    )

def dt_convlstm_drop05_2d(width, **kwargs):
    """Build a NetConvLSTM with fixed-mask dropout p=0.5 on LSTM inputs and outputs.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.5
    return NetConvLSTM(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
    )


## fix dropout
class NetConvLSTM2(nn.Module):
    """Three stacked ConvLSTMCell layers with recall at every layer.

    Same as the 3-layer NetConvLSTM, except that ``x`` is concatenated to the
    input of *each* of the three layers at every step. ``x`` is first projected
    to ``width`` channels (3x3 conv + ReLU); each requested iteration runs the
    stack ``mul`` (= 2) times, and a small conv head maps the thought to 2
    output channels after every ``mul`` stack steps.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,
                  _dropout_i=0,_dropout_h=0,**kwargs):
        """Build the projection, head, ConvLSTM stack and fixed-mask dropouts.

        Args:
            block, num_blocks: not used by ``__init__`` (the residual recurrent
                block is commented out); only ``_make_layer`` uses ``block``.
            width: channel count of the recurrent "thought".
            in_channels: channel count of ``x``.
            recall: stored; ``forward`` only supports ``True``.
            group_norm: stored for ``_make_layer``; otherwise unused.
            bias: bias flag of the projection and head convs. The ConvLSTMCell
                layers get a hard-coded ``True`` as their 4th argument
                (presumably their bias flag -- TODO confirm).
            _dropout_i: drop probability of ``self._input_drop``.
            _dropout_h: drop probability of ``self._state_drop``.
            **kwargs: ignored.
        """
        super().__init__()

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        # conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
        #                         stride=1, padding=1, bias=bias)

        # recur_layers = []
        # if recall:
        #     recur_layers.append(conv_recall)

        # for i in range(len(num_blocks)):
        #     recur_layers.append(self._make_layer(block, width, num_blocks[i], stride=1))

        # (refers to the commented-out residual design above)
        # each recurrent layer runs 1 + 2*2 convolutions = 5 convolutions

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        # self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)


        # Every layer reads [thought, x], hence width + in_channels inputs each.
        # self.lstm = LSTMCellImproved(width, width)
        self.lstm = ConvLSTMCell(width + in_channels, width, (3,3), True)
        self.lstm2 = ConvLSTMCell(width+ in_channels, width, (3,3), True)
        self.lstm3 = ConvLSTMCell(width+ in_channels, width, (3,3), True)


        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        # The same two masks are reused for all three layers and every step
        # of a forward pass (drawn once in forward at i == 0).
        self._input_drop = SampleDropConv(dropout=self._dropout_i)
        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` (the first with ``stride``).

        Side effect: sets ``self.width`` to ``planes * block.expansion``.
        Not called by ``__init__`` or ``forward`` of this class.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` thinking iterations (``2 * iters_to_do`` stack steps).

        Args:
            x: input batch of shape (B, in_channels, H, W).
            iters_to_do: number of head outputs to produce.
            interim_thought: must be None; resuming from a given thought is not
                implemented (AssertionError, unless Python runs with -O).
            return_all_outputs: in training mode, also return ``all_outputs``.
            **kwargs: ignored.

        Returns:
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)``
            when ``return_all_outputs``; ``out`` is the last head output and
            ``thought`` the last (post-dropout) layer-3 output.
            Eval mode: ``all_outputs`` of shape (B, iters_to_do, 2, H, W).
            In training mode with ``iters_to_do == 0`` the head never runs and
            the return raises UnboundLocalError on ``out``.
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought
        else:
            assert False, "not implemented"

        # One slot per requested iteration; filled every `mul` stack steps.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3))).to(x.device)
        
        # interim_thought_h, c = self.lstm(self._input_drop(interim_thought_flat))
        # interim_thought = rearrange(interim_thought_h, '(b h w) c -> b c h w', b=x.size(0), h=x.size(2), w=x.size(3))
        
        # Number of 3-layer stack steps per head output.
        mul=2

        for i in range(iters_to_do*mul):
            if i==0:
                # Draw this pass's fixed dropout masks, shaped like the
                # projected thought.
                self._input_drop.set_weights(interim_thought)
                self._state_drop.set_weights(interim_thought)

                # None lets ConvLSTMCell create its own initial state
                # (presumably zeros -- TODO confirm in ConvLSTMCell).
                state = None
            else:
                # NOTE(review): layer 1's hidden state is the previous step's
                # *layer-3* output (after dropout), paired with layer 1's own c.
                state = (interim_thought,c)
                

            if self.recall:
                interim_thought_new = torch.cat([self._input_drop(interim_thought), x], 1)
            else:
                assert False, "not implemented"

            interim_thought, c = self.lstm(interim_thought_new,state)
            interim_thought = self._state_drop(interim_thought)

            # NOTE(review): layer 2's hidden state is layer 1's current output;
            # only its cell state c2 is carried between steps (zeros at i == 0).
            if i==0:
                state2=(interim_thought,torch.zeros_like(c).to(c.device))
            else:
                # should it be state drop from other lstm though?
                state2 = (interim_thought,c2)

            # Recall: layer 2 also sees the raw input x.
            if self.recall:
                interim_thought_new = torch.cat([self._input_drop(interim_thought), x], 1)
            else:
                assert False, "not implemented"

            interim_thought, c2 = self.lstm2(interim_thought_new,state2)
            interim_thought = self._state_drop(interim_thought)

            # Same wiring for layer 3: h = layer 2's output, c3 carried over.
            if i==0:
                state3=(interim_thought,torch.zeros_like(c).to(c.device))
            else:
                state3 = (interim_thought,c3)

            # Recall: layer 3 also sees the raw input x.
            if self.recall:
                interim_thought_new = torch.cat([self._input_drop(interim_thought), x], 1)
            else:
                assert False, "not implemented"

            interim_thought, c3 = self.lstm3(interim_thought_new,state3)
            interim_thought = self._state_drop(interim_thought)

            # Decode only on the last stack step of each group of `mul`.
            if i%mul==mul-1:
                out = self.head(interim_thought)
                all_outputs[:, i//mul] = out

            # out = self.head(interim_thought)
            # all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs
    
def dt_convlstm_allconcat_2d(width, **kwargs):
    """Build a NetConvLSTM2 (recall at every LSTM layer) without dropout.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    n_in = kwargs["in_channels"]
    return NetConvLSTM2(BasicBlock, [2], width=width, in_channels=n_in, recall=True)


## fix dropout
class NetConvLSTM3(nn.Module):
    """Recurrent thinking network driven by a single ConvLSTMCell.

    ``x`` is projected to ``width`` channels (3x3 conv + ReLU). The ConvLSTM
    cell is then stepped ``mul`` times per requested iteration, with ``x``
    concatenated to its input at every step (recall). After every ``mul``
    steps a small conv head maps the hidden state to 2 output channels.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout_i=0, _dropout_h=0, mul=5, **kwargs):
        """Build the projection, head, ConvLSTM cell and fixed-mask dropouts.

        Args:
            block, num_blocks: not used by ``__init__``; kept so the signature
                matches the other model classes (``_make_layer`` uses ``block``).
            width: channel count of the recurrent "thought".
            in_channels: channel count of ``x``.
            recall: stored; ``forward`` only supports ``True``.
            group_norm: stored for ``_make_layer``; otherwise unused.
            bias: bias flag of the projection and head convs. The ConvLSTMCell
                layers get a hard-coded ``True`` as their 4th argument
                (presumably their bias flag -- TODO confirm).
            _dropout_i, _dropout_h: drop probabilities of the fixed-mask
                dropout on the LSTM input and output.
            mul: LSTM steps per requested iteration (must be >= 1).
            **kwargs: ignored.

        Raises:
            ValueError: if ``mul < 1``.
        """
        super().__init__()

        if mul < 1:
            raise ValueError(f"mul must be >= 1, got {mul}")

        self.name = "NetConvLSTM_1layer"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        self.mul = mul

        # Modules are created in the same order as before so that seeded
        # initialization and state_dict keys are unchanged.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)
        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        # The cell reads [thought, x], hence width + in_channels inputs.
        self.lstm = ConvLSTMCell(width + in_channels, width, (3, 3), True)
        # lstm2/lstm3 are never used by forward(). They are still created so
        # that seeded initialization and saved checkpoints stay compatible.
        # NOTE: DistributedDataParallel will likely need
        # find_unused_parameters=True because of them.
        self.lstm2 = ConvLSTMCell(width + in_channels, width, (3, 3), True)
        self.lstm3 = ConvLSTMCell(width + in_channels, width, (3, 3), True)

        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        # One mask per forward pass, reused at every LSTM step.
        self._input_drop = SampleDropConv(dropout=self._dropout_i)
        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` (the first with ``stride``).

        Side effect: sets ``self.width`` to ``planes * block.expansion``.
        Not called by ``__init__`` or ``forward`` of this class.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` thinking iterations (``mul * iters_to_do`` LSTM steps).

        Args:
            x: input batch of shape (B, in_channels, H, W).
            iters_to_do: number of head outputs to produce.
            interim_thought: must be None; resuming from a thought is not
                implemented.
            return_all_outputs: in training mode, also return ``all_outputs``.
            **kwargs: ignored.

        Returns:
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)``
            when ``return_all_outputs``; ``out`` is the last head output and
            ``thought`` the last (post-dropout) hidden state.
            Eval mode: ``all_outputs`` of shape (B, iters_to_do, 2, H, W).

        Raises:
            NotImplementedError: if ``interim_thought`` is given or
                ``self.recall`` is False.
            ValueError: in training mode when ``iters_to_do < 1`` (there
                would be no ``out`` to return).
        """
        # Explicit raises instead of `assert False`, which -O strips.
        if interim_thought is not None:
            raise NotImplementedError("not implemented")
        if not self.recall:
            raise NotImplementedError("not implemented")
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        interim_thought = self.projection(x)
        mul = self.mul
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        state = None  # None: the cell builds its own initial state
        out = None
        for i in range(iters_to_do * mul):
            if i == 0:
                # Draw this pass's fixed dropout masks, shaped like the thought.
                self._input_drop.set_weights(interim_thought)
                self._state_drop.set_weights(interim_thought)

            lstm_in = torch.cat([self._input_drop(interim_thought), x], 1)
            interim_thought, c = self.lstm(lstm_in, state)
            interim_thought = self._state_drop(interim_thought)
            # The post-dropout output is the hidden state of the next step.
            state = (interim_thought, c)

            # Decode only on the last step of each group of `mul`.
            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs
    
def dt_convlstm_1lstm_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3 (5 LSTM steps per output, no dropout).

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    n_in = kwargs["in_channels"]
    return NetConvLSTM3(BasicBlock, [2], width=width, in_channels=n_in, recall=True)

def dt_convlstm_1lstm_drop01_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3 with fixed-mask dropout p=0.1.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.1
    return NetConvLSTM3(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
    )


def dt_convlstm_1lst_mul6_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3 with 6 LSTM steps per output, no dropout.

    (The function name keeps its historical "1lst" spelling.)
    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    steps_per_output = 6
    return NetConvLSTM3(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        mul=steps_per_output,
    )



def dt_convlstm_1lstm_drop02_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3 with fixed-mask dropout p=0.2.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.2
    return NetConvLSTM3(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
    )

def dt_convlstm_1lstm_drop03_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3 with fixed-mask dropout p=0.3.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.3
    return NetConvLSTM3(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
    )

def dt_convlstm_1lstm_drop04_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3 with fixed-mask dropout p=0.4.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.4
    return NetConvLSTM3(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
    )


def dt_convlstm_1lstm_drop01_mul6_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3: dropout p=0.1, 6 LSTM steps per output.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.1
    return NetConvLSTM3(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
        mul=6,
    )


def dt_convlstm_1lstm_drop02_mul6_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3: dropout p=0.2, 6 LSTM steps per output.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.2
    return NetConvLSTM3(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
        mul=6,
    )


def dt_convlstm_1lstm_drop03_mul6_2d(width, **kwargs):
    """Build a single-cell NetConvLSTM3: dropout p=0.3, 6 LSTM steps per output.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.3
    return NetConvLSTM3(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
        mul=6,
    )


class ConvLSTMCell_2(nn.Module):
    """ConvLSTM cell with a constant +1 offset on the forget-gate logits.

    All four gates come from a single convolution over ``[input, h]``. Adding
    1 before the forget-gate sigmoid pushes ``f`` towards 1, so the cell keeps
    more of its previous cell state than with the plain logits.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel. Padding is ``k // 2`` per axis,
            so odd sizes keep the spatial size unchanged.
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One conv yields the i, f, o, g pre-activations stacked on channels.
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

        # self.reset_parameters()  # disabled: PyTorch's default Conv2d init is kept

    def forward(self, input_tensor, cur_state=None):
        """Advance the cell by one step.

        Parameters
        ----------
        input_tensor: Tensor
            Input of shape (B, input_dim, H, W).
        cur_state: (Tensor, Tensor) or None
            ``(h, c)``, each (B, hidden_dim, H, W); ``None`` starts from zeros.

        Returns
        -------
        (h_next, c_next), each of shape (B, hidden_dim, H, W).
        """
        if cur_state is None:
            cur_state = self.init_hidden(input_tensor.shape[0], (input_tensor.shape[2], input_tensor.shape[3]))
        h_cur, c_cur = cur_state

        gates = self.conv(torch.cat([input_tensor, h_cur], dim=1))  # concat on channels
        cc_i, cc_f, cc_o, cc_g = torch.split(gates, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f + 1)  # +1 forget offset (alternative to a bias init)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` tensors of shape (batch_size, hidden_dim, H, W).

        They match the conv weight's device *and dtype*, so the cell also
        works after ``.double()`` / ``.half()``.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

    def reset_parameters(self):
        """Xavier-uniform init (tanh gain) of the gate conv; zero its bias if present."""
        nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain('tanh'))
        # With bias=False, conv.bias is None; the old code crashed here.
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)


class NetConvLSTM4(nn.Module):
    """Recurrent thinking network driven by one ConvLSTMCell_2 (+1 forget offset).

    ``x`` is projected to ``width`` channels (3x3 conv + ReLU). The ConvLSTM
    cell is then stepped ``mul`` times per requested iteration, with ``x``
    concatenated to its input at every step (recall). After every ``mul``
    steps a small conv head maps the hidden state to 2 output channels.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout_i=0, _dropout_h=0, mul=5, **kwargs):
        """Build the projection, head, ConvLSTM cell and fixed-mask dropouts.

        Args:
            block, num_blocks: not used by ``__init__``; kept so the signature
                matches the other model classes (``_make_layer`` uses ``block``).
            width: channel count of the recurrent "thought".
            in_channels: channel count of ``x``.
            recall: stored; ``forward`` only supports ``True``.
            group_norm: stored for ``_make_layer``; otherwise unused.
            bias: bias flag of the projection and head convs (the ConvLSTM
                cell is always built with ``bias=True``).
            _dropout_i, _dropout_h: drop probabilities of the fixed-mask
                dropout on the LSTM input and output.
            mul: LSTM steps per requested iteration. This was previously
                hard-coded to 5; the default keeps that behavior.
            **kwargs: ignored.

        Raises:
            ValueError: if ``mul < 1``.
        """
        super().__init__()

        if mul < 1:
            raise ValueError(f"mul must be >= 1, got {mul}")

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        self.mul = mul

        # Modules are created in the same order as before so that seeded
        # initialization and state_dict keys are unchanged.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)
        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        # The cell reads [thought, x], hence width + in_channels inputs.
        self.lstm = ConvLSTMCell_2(width + in_channels, width, (3, 3), True)

        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        # One mask per forward pass, reused at every LSTM step.
        self._input_drop = SampleDropConv(dropout=self._dropout_i)
        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` (the first with ``stride``).

        Side effect: sets ``self.width`` to ``planes * block.expansion``.
        Not called by ``__init__`` or ``forward`` of this class.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` thinking iterations (``mul * iters_to_do`` LSTM steps).

        Args:
            x: input batch of shape (B, in_channels, H, W).
            iters_to_do: number of head outputs to produce.
            interim_thought: must be None; resuming from a thought is not
                implemented.
            return_all_outputs: in training mode, also return ``all_outputs``.
            **kwargs: ignored.

        Returns:
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)``
            when ``return_all_outputs``; ``out`` is the last head output and
            ``thought`` the last (post-dropout) hidden state.
            Eval mode: ``all_outputs`` of shape (B, iters_to_do, 2, H, W).

        Raises:
            NotImplementedError: if ``interim_thought`` is given or
                ``self.recall`` is False.
            ValueError: in training mode when ``iters_to_do < 1`` (there
                would be no ``out`` to return).
        """
        # Explicit raises instead of `assert False`, which -O strips.
        if interim_thought is not None:
            raise NotImplementedError("not implemented")
        if not self.recall:
            raise NotImplementedError("not implemented")
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        interim_thought = self.projection(x)
        mul = self.mul
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        state = None  # None: ConvLSTMCell_2 starts from zero (h, c)
        out = None
        for i in range(iters_to_do * mul):
            if i == 0:
                # Draw this pass's fixed dropout masks, shaped like the thought.
                self._input_drop.set_weights(interim_thought)
                self._state_drop.set_weights(interim_thought)

            lstm_in = torch.cat([self._input_drop(interim_thought), x], 1)
            interim_thought, c = self.lstm(lstm_in, state)
            interim_thought = self._state_drop(interim_thought)
            # The post-dropout output is the hidden state of the next step.
            state = (interim_thought, c)

            # Decode only on the last step of each group of `mul`.
            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs
    
def dt_convlstm_1lstm_py1bias_2d(width, **kwargs):
    """Build a NetConvLSTM4 (single ConvLSTMCell_2 with +1 forget offset), no dropout.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    n_in = kwargs["in_channels"]
    return NetConvLSTM4(BasicBlock, [2], width=width, in_channels=n_in, recall=True)




class ConvLSTMCellPeep(nn.Module):
    """ConvLSTM cell with convolutional peephole connections.

    Besides the gate convolution over ``[input, h]``, the previous cell state
    feeds the input and forget gates through ``conv_c``, and the new cell
    state feeds the output gate through ``conv_cnext``. The ConvLSTM paper
    (Shi et al., 2015) uses elementwise (Hadamard) peephole weights. Here the
    peepholes are full convolutions instead, which costs two extra convs per
    step.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.
        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel. Padding is ``k // 2`` per axis,
            so odd sizes keep the spatial size unchanged.
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # Gate pre-activations (i, f, o, g) from [input, h].
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)
        # Peephole from the previous cell state into the i and f gates.
        self.conv_c = nn.Conv2d(in_channels=self.hidden_dim,
                                out_channels=2 * self.hidden_dim,
                                kernel_size=self.kernel_size,
                                padding=self.padding,
                                bias=self.bias)
        # Peephole from the new cell state into the o gate.
        self.conv_cnext = nn.Conv2d(in_channels=self.hidden_dim,
                                    out_channels=self.hidden_dim,
                                    kernel_size=self.kernel_size,
                                    padding=self.padding,
                                    bias=self.bias)

    def forward(self, input_tensor, cur_state=None):
        """Advance the cell by one step.

        Parameters
        ----------
        input_tensor: Tensor
            Input of shape (B, input_dim, H, W).
        cur_state: (Tensor, Tensor) or None
            ``(h, c)``, each (B, hidden_dim, H, W); ``None`` (now also the
            default) starts from zeros.

        Returns
        -------
        (h_next, c_next), each of shape (B, hidden_dim, H, W).
        """
        if cur_state is None:
            cur_state = self.init_hidden(input_tensor.shape[0], (input_tensor.shape[2], input_tensor.shape[3]))
        h_cur, c_cur = cur_state

        combined_conv = self.conv(torch.cat([input_tensor, h_cur], dim=1))  # concat on channels
        cprev_i, cprev_f = torch.split(self.conv_c(c_cur), self.hidden_dim, dim=1)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i + cprev_i)
        f = torch.sigmoid(cc_f + cprev_f)
        g = torch.tanh(cc_g)
        c_next = f * c_cur + i * g

        # Output-gate peephole on the *new* cell state; may be costly, because
        # this step runs one more conv.
        o = torch.sigmoid(cc_o + self.conv_cnext(c_next))
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` tensors of shape (batch_size, hidden_dim, H, W).

        They match the conv weight's device *and dtype*, so the cell also
        works after ``.double()`` / ``.half()``.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))



class NetConvLSTM5(nn.Module):
    """Recurrent thinking network driven by one peephole ConvLSTM cell.

    ``x`` is projected to ``width`` channels (3x3 conv + ReLU). The
    ConvLSTMCellPeep cell is then stepped ``mul`` times per requested
    iteration, with ``x`` concatenated to its input at every step (recall).
    After every ``mul`` steps a small conv head maps the hidden state to 2
    output channels.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout_i=0, _dropout_h=0, mul=5, **kwargs):
        """Build the projection, head, peephole ConvLSTM cell and fixed-mask dropouts.

        Args:
            block, num_blocks: not used by ``__init__``; kept so the signature
                matches the other model classes (``_make_layer`` uses ``block``).
            width: channel count of the recurrent "thought".
            in_channels: channel count of ``x``.
            recall: stored; ``forward`` only supports ``True``.
            group_norm: stored for ``_make_layer``; otherwise unused.
            bias: bias flag of the projection and head convs (the ConvLSTM
                cell is always built with ``bias=True``).
            _dropout_i, _dropout_h: drop probabilities of the fixed-mask
                dropout on the LSTM input and output.
            mul: LSTM steps per requested iteration. This was previously
                hard-coded to 5; the default keeps that behavior.
            **kwargs: ignored.

        Raises:
            ValueError: if ``mul < 1``.
        """
        super().__init__()

        if mul < 1:
            raise ValueError(f"mul must be >= 1, got {mul}")

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        self.mul = mul

        # Modules are created in the same order as before so that seeded
        # initialization and state_dict keys are unchanged.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)
        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        # The cell reads [thought, x], hence width + in_channels inputs.
        self.lstm = ConvLSTMCellPeep(width + in_channels, width, (3, 3), True)

        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        # One mask per forward pass, reused at every LSTM step.
        self._input_drop = SampleDropConv(dropout=self._dropout_i)
        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` (the first with ``stride``).

        Side effect: sets ``self.width`` to ``planes * block.expansion``.
        Not called by ``__init__`` or ``forward`` of this class.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` thinking iterations (``mul * iters_to_do`` LSTM steps).

        Args:
            x: input batch of shape (B, in_channels, H, W).
            iters_to_do: number of head outputs to produce.
            interim_thought: must be None; resuming from a thought is not
                implemented.
            return_all_outputs: in training mode, also return ``all_outputs``.
            **kwargs: ignored.

        Returns:
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)``
            when ``return_all_outputs``; ``out`` is the last head output and
            ``thought`` the last (post-dropout) hidden state.
            Eval mode: ``all_outputs`` of shape (B, iters_to_do, 2, H, W).

        Raises:
            NotImplementedError: if ``interim_thought`` is given or
                ``self.recall`` is False.
            ValueError: in training mode when ``iters_to_do < 1`` (there
                would be no ``out`` to return).
        """
        # Explicit raises instead of `assert False`, which -O strips.
        if interim_thought is not None:
            raise NotImplementedError("not implemented")
        if not self.recall:
            raise NotImplementedError("not implemented")
        if self.training and iters_to_do < 1:
            raise ValueError(f"iters_to_do must be >= 1 in training mode, got {iters_to_do}")

        interim_thought = self.projection(x)
        mul = self.mul
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        state = None  # None: ConvLSTMCellPeep starts from zero (h, c)
        out = None
        for i in range(iters_to_do * mul):
            if i == 0:
                # Draw this pass's fixed dropout masks, shaped like the thought.
                self._input_drop.set_weights(interim_thought)
                self._state_drop.set_weights(interim_thought)

            lstm_in = torch.cat([self._input_drop(interim_thought), x], 1)
            interim_thought, c = self.lstm(lstm_in, state)
            interim_thought = self._state_drop(interim_thought)
            # The post-dropout output is the hidden state of the next step.
            state = (interim_thought, c)

            # Decode only on the last step of each group of `mul`.
            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs
    
def dt_convlstm_1lstm_peep_2d(width, **kwargs):
    """Build a NetConvLSTM5 (single peephole ConvLSTM cell) without dropout.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    n_in = kwargs["in_channels"]
    return NetConvLSTM5(BasicBlock, [2], width=width, in_channels=n_in, recall=True)

def dt_convlstm_1lstm_peep_drop01_2d(width, **kwargs):
    """Build a NetConvLSTM5 (single peephole ConvLSTM cell) with dropout p=0.1.

    ``kwargs`` must provide ``in_channels``; any other key is ignored.
    """
    p = 0.1
    return NetConvLSTM5(
        BasicBlock,
        [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        _dropout_i=p,
        _dropout_h=p,
    )



## fix dropout
class NetConvLSTM6(nn.Module):
    """DeepThinking 2D model built from three chained ConvLSTM cells.

    ``forward`` projects ``x`` to an initial thought. At every recurrence step
    it feeds ``cat([input_drop(thought), x])`` into ``self.lstm`` and chains the
    result through ``self.lstm2`` and ``self.lstm3``. Each cell carries its own
    ``(h, c)`` state from one step to the next. The head decodes the thought
    every second step, so ``iters_to_do`` outputs are produced.

    ``block``, ``num_blocks`` and ``group_norm`` are accepted for signature
    compatibility with the other DeepThinking models. This constructor does not
    build a residual recurrent block from them.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout_i=0, _dropout_h=0, **kwargs):
        """
        Args:
            block, num_blocks: unused by this constructor (see class docstring).
            width: channel count of the thought / LSTM hidden state.
            in_channels: channel count of ``x``.
            recall: must be True for ``forward`` to run.
            group_norm: stored; only read by ``_make_layer``.
            bias: whether the projection and head convolutions have a bias.
            _dropout_i: rate for ``self._input_drop`` (applied to cell inputs).
            _dropout_h: rate for ``self._state_drop`` (applied to cell outputs).
            **kwargs: ignored.
        """
        super().__init__()

        self.name = "NetConvLSTM6"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Layer construction order is unchanged so seeded initialisation and
        # state_dict layout match earlier versions.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        # The first cell sees the thought concatenated with x (recall), hence
        # width + in_channels input channels; the next two see only the thought.
        self.lstm = ConvLSTMCell(width + in_channels, width, (3, 3), True)
        self.lstm2 = ConvLSTMCell(width, width, (3, 3), True)
        self.lstm3 = ConvLSTMCell(width, width, (3, 3), True)

        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        self._input_drop = SampleDropConv(dropout=self._dropout_i)
        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``; the rest use stride 1. Updates
        ``self.width`` to ``planes * block.expansion`` as it goes. Not called by
        this class's constructor; kept for parity with the DTNet family.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do * 2`` recurrence steps and decode every second one.

        Args:
            x: input batch, indexed as ``(batch, channels, height, width)``.
            iters_to_do: number of decoded outputs to produce.
            interim_thought: must be ``None``; resuming is not supported.
            return_all_outputs: in training mode, also return the stacked outputs.
            **kwargs: ignored.

        Returns:
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, height, width)``.
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)`` when
            ``return_all_outputs`` is set; ``out`` is the last decoded output.

        Raises:
            NotImplementedError: if ``interim_thought`` is given or ``self.recall`` is False.
        """
        # Explicit raises instead of ``assert False``: asserts vanish under ``python -O``.
        if interim_thought is not None:
            raise NotImplementedError("NetConvLSTM6 cannot resume from interim_thought")
        if not self.recall:
            raise NotImplementedError("NetConvLSTM6 requires recall=True")

        interim_thought = self.projection(x)

        # Allocate straight on x's device (default dtype) instead of CPU + copy.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        mul = 2  # recurrence steps per decoded output

        for i in range(iters_to_do * mul):
            if i == 0:
                # set_weights is only called once per unroll, with the initial thought.
                self._input_drop.set_weights(interim_thought)
                self._state_drop.set_weights(interim_thought)
                state = None
            else:
                # NOTE: the carried hidden state is the raw h1, not its state-dropped copy.
                state = (h1, c)

            lstm_in = torch.cat([self._input_drop(interim_thought), x], 1)
            h1, c = self.lstm(lstm_in, state)
            interim_thought = self._state_drop(h1)

            # On the first step lstm2/lstm3 start from the state-dropped output of
            # the cell below and a zero cell state; afterwards they carry their own
            # (un-dropped) state. NOTE(review): the original author questioned
            # whether the carried state should be the state-dropped one instead.
            state2 = (interim_thought, torch.zeros_like(c)) if i == 0 else (h2, c2)
            h2, c2 = self.lstm2(self._input_drop(interim_thought), state2)
            interim_thought = self._state_drop(h2)

            state3 = (interim_thought, torch.zeros_like(c)) if i == 0 else (h3, c3)
            h3, c3 = self.lstm3(self._input_drop(interim_thought), state3)
            interim_thought = self._state_drop(h3)

            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs
    
def dt_convlstm_diff_h_2d(width, **kwargs):
    """Build a recall-enabled ``NetConvLSTM6`` with no dropout (both rates at their 0 default).

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM6(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True)

def dt_convlstm_diff_h_drop02_2d(width, **kwargs):
    """Build a recall-enabled ``NetConvLSTM6`` with input and state dropout of 0.2.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM6(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_i=0.2, _dropout_h=0.2)

def dt_convlstm_diff_h_drop01_2d(width, **kwargs):
    """Build a recall-enabled ``NetConvLSTM6`` with input and state dropout of 0.1.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM6(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_i=0.1, _dropout_h=0.1)

## fix dropout
class NetConvLSTM7(nn.Module):
    """DeepThinking 2D model: three ConvLSTM cells that all read ``x`` directly.

    At each recurrence step the hidden state is chained
    ``lstm -> lstm2 -> lstm3``, and the output of ``lstm3`` feeds ``lstm`` on
    the next step. Each cell keeps a private cell state ("diff c"). Every
    cell's output passes through ``self._state_drop``. The head decodes the
    thought every second step.

    The projection output is only used to initialise the dropout masks and the
    first hidden states; ``x`` itself is the input of every cell. ``block``,
    ``num_blocks``, ``recall`` and ``group_norm`` are accepted for signature
    compatibility and not used to build layers.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout_i=0, _dropout_h=0, **kwargs):
        """
        Args:
            block, num_blocks: unused by this constructor.
            width: channel count of the hidden state.
            in_channels: channel count of ``x`` (the input of every cell).
            recall: stored only; ``forward`` does not read it.
            group_norm: stored; only read by ``_make_layer``.
            bias: whether the projection and head convolutions have a bias.
            _dropout_i: rate for ``self._input_drop``; that module is never
                applied to a tensor in ``forward``.
            _dropout_h: rate for ``self._state_drop`` (applied to every cell output).
            **kwargs: ignored.
        """
        super().__init__()

        self.name = "NetConvLSTM7"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Layer construction order is unchanged so seeded initialisation and
        # state_dict layout match earlier versions.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.lstm = ConvLSTMCell(in_channels, width, (3, 3), True)
        self.lstm2 = ConvLSTMCell(in_channels, width, (3, 3), True)
        self.lstm3 = ConvLSTMCell(in_channels, width, (3, 3), True)

        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        self._input_drop = SampleDropConv(dropout=self._dropout_i)
        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``; the rest use stride 1. Updates
        ``self.width`` to ``planes * block.expansion`` as it goes. Not called by
        this class's constructor; kept for parity with the DTNet family.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do * 2`` three-cell steps and decode every second one.

        Args:
            x: input batch, indexed as ``(batch, channels, height, width)``.
            iters_to_do: number of decoded outputs to produce.
            interim_thought: must be ``None``; resuming is not supported.
            return_all_outputs: in training mode, also return the stacked outputs.
            **kwargs: ignored.

        Returns:
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, height, width)``.
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)`` when
            ``return_all_outputs`` is set; ``out`` is the last decoded output.

        Raises:
            NotImplementedError: if ``interim_thought`` is given.
        """
        # Explicit raise instead of ``assert False``: asserts vanish under ``python -O``.
        if interim_thought is not None:
            raise NotImplementedError("NetConvLSTM7 cannot resume from interim_thought")

        interim_thought = self.projection(x)

        # Allocate straight on x's device (default dtype) instead of CPU + copy.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        mul = 2  # recurrence steps per decoded output

        for i in range(iters_to_do * mul):
            if i == 0:
                # set_weights is only called once per unroll. _input_drop is never
                # applied below; its set_weights call is kept in case it has side
                # effects (e.g. RNG use).
                self._input_drop.set_weights(interim_thought)
                self._state_drop.set_weights(interim_thought)
                state = None
            else:
                state = (interim_thought, c)

            interim_thought, c = self.lstm(x, state)
            interim_thought = self._state_drop(interim_thought)

            # lstm2/lstm3 take the hidden state from the cell below and their own
            # cell state (zeros on the very first step).
            state2 = (interim_thought, torch.zeros_like(c) if i == 0 else c2)
            interim_thought, c2 = self.lstm2(x, state2)
            interim_thought = self._state_drop(interim_thought)

            state3 = (interim_thought, torch.zeros_like(c) if i == 0 else c3)
            interim_thought, c3 = self.lstm3(x, state3)
            interim_thought = self._state_drop(interim_thought)

            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def dt_convlstm_onlyx_diff_c_2d(width, **kwargs):
    """Build a ``NetConvLSTM7`` with no dropout (both rates at their 0 default).

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM7(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True)

def dt_convlstm_onlyx_diff_c_drop01_2d(width, **kwargs):
    """Build a ``NetConvLSTM7`` with ``_dropout_i=0.1`` and ``_dropout_h=0.1``.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM7(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_i=0.1, _dropout_h=0.1)

def dt_convlstm_onlyx_diff_c_drop02_2d(width, **kwargs):
    """Build a ``NetConvLSTM7`` with hidden-state dropout ``_dropout_h=0.2`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM7(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.2)

def dt_convlstm_onlyx_diff_c_drop03_2d(width, **kwargs):
    """Build a ``NetConvLSTM7`` with hidden-state dropout ``_dropout_h=0.3`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM7(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.3)

def dt_convlstm_onlyx_diff_c_drop04_2d(width, **kwargs):
    """Build a ``NetConvLSTM7`` with hidden-state dropout ``_dropout_h=0.4`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM7(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.4)
def dt_convlstm_onlyx_diff_c_drop05_2d(width, **kwargs):
    """Build a ``NetConvLSTM7`` with hidden-state dropout ``_dropout_h=0.5`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM7(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.5)


## fix dropout
class NetConvLSTM8(nn.Module):
    """DeepThinking 2D model: three ConvLSTM cells reading ``x`` with one shared cell state.

    Within each recurrence step the ``(h, c)`` pair produced by one cell
    (after ``self._state_drop`` on ``h``) is the state of the next cell, in the
    order ``lstm -> lstm2 -> lstm3``. The pair produced by ``lstm3`` seeds
    ``lstm`` on the following step ("same c"). The head decodes the thought
    every second step.

    The projection output is only used to initialise the dropout masks; ``x``
    itself is the input of every cell. ``block``, ``num_blocks``, ``recall``
    and ``group_norm`` are accepted for signature compatibility and not used
    to build layers.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout_i=0, _dropout_h=0, **kwargs):
        """
        Args:
            block, num_blocks: unused by this constructor.
            width: channel count of the hidden state.
            in_channels: channel count of ``x`` (the input of every cell).
            recall: stored only; ``forward`` does not read it.
            group_norm: stored; only read by ``_make_layer``.
            bias: whether the projection and head convolutions have a bias.
            _dropout_i: rate for ``self._input_drop``; that module is never
                applied to a tensor in ``forward``.
            _dropout_h: rate for ``self._state_drop`` (applied to every cell output).
            **kwargs: ignored.
        """
        super().__init__()

        self.name = "NetConvLSTM8"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Layer construction order is unchanged so seeded initialisation and
        # state_dict layout match earlier versions.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.lstm = ConvLSTMCell(in_channels, width, (3, 3), True)
        self.lstm2 = ConvLSTMCell(in_channels, width, (3, 3), True)
        self.lstm3 = ConvLSTMCell(in_channels, width, (3, 3), True)

        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        self._input_drop = SampleDropConv(dropout=self._dropout_i)
        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``; the rest use stride 1. Updates
        ``self.width`` to ``planes * block.expansion`` as it goes. Not called by
        this class's constructor; kept for parity with the DTNet family.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do * 2`` three-cell steps and decode every second one.

        Args:
            x: input batch, indexed as ``(batch, channels, height, width)``.
            iters_to_do: number of decoded outputs to produce.
            interim_thought: must be ``None``; resuming is not supported.
            return_all_outputs: in training mode, also return the stacked outputs.
            **kwargs: ignored.

        Returns:
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, height, width)``.
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)`` when
            ``return_all_outputs`` is set; ``out`` is the last decoded output.

        Raises:
            NotImplementedError: if ``interim_thought`` is given.
        """
        # Explicit raise instead of ``assert False``: asserts vanish under ``python -O``.
        if interim_thought is not None:
            raise NotImplementedError("NetConvLSTM8 cannot resume from interim_thought")

        interim_thought = self.projection(x)

        # Allocate straight on x's device (default dtype) instead of CPU + copy.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        mul = 2  # recurrence steps per decoded output
        cells = (self.lstm, self.lstm2, self.lstm3)

        # One (h, c) pair threads through every cell and every step; the very
        # first cell call starts from state None.
        state = None
        for i in range(iters_to_do * mul):
            if i == 0:
                # set_weights is only called once per unroll. _input_drop is never
                # applied below; its set_weights call is kept in case it has side
                # effects (e.g. RNG use).
                self._input_drop.set_weights(interim_thought)
                self._state_drop.set_weights(interim_thought)

            for cell in cells:
                interim_thought, c = cell(x, state)
                interim_thought = self._state_drop(interim_thought)
                state = (interim_thought, c)

            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def dt_convlstm_onlyx_same_c_2d(width, **kwargs):
    """Build a ``NetConvLSTM8`` with no dropout (both rates at their 0 default).

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM8(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True)

def dt_convlstm_onlyx_same_c_drop01_2d(width, **kwargs):
    """Build a ``NetConvLSTM8`` with ``_dropout_i=0.1`` and ``_dropout_h=0.1``.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM8(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_i=0.1, _dropout_h=0.1)

def dt_convlstm_onlyx_same_c_drop02_2d(width, **kwargs):
    """Build a ``NetConvLSTM8`` with ``_dropout_i=0.2`` and ``_dropout_h=0.2``.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM8(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_i=0.2, _dropout_h=0.2)

def dt_convlstm_onlyx_same_c_drop03_2d(width, **kwargs):
    """Build a ``NetConvLSTM8`` with hidden-state dropout ``_dropout_h=0.3`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM8(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.3)

def dt_convlstm_onlyx_same_c_drop04_2d(width, **kwargs):
    """Build a ``NetConvLSTM8`` with hidden-state dropout ``_dropout_h=0.4`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM8(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.4)

def dt_convlstm_onlyx_same_c_drop05_2d(width, **kwargs):
    """Build a ``NetConvLSTM8`` with hidden-state dropout ``_dropout_h=0.5`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM8(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.5)


## fix dropout
class NetConvLSTM9(nn.Module):
    """DeepThinking 2D model: a single ConvLSTM cell reading ``x``, unrolled 6 steps per output.

    The original comment said ``mul = 6`` was chosen "to be equivalent to 3
    LSTMs". That matches the three-cell variants in this file, which run
    3 cells x 2 steps = 6 cell applications per decoded output. The cell's
    hidden output passes through ``self._state_drop`` after every step.

    The projection output is only used to pick the device for
    ``self.lstm.sample_mask`` and to initialise the state-dropout mask. ``block``,
    ``num_blocks``, ``recall`` and ``group_norm`` are accepted for signature
    compatibility and not used to build layers.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout_i=0, _dropout_h=0, **kwargs):
        """
        Args:
            block, num_blocks: unused by this constructor.
            width: channel count of the hidden state.
            in_channels: channel count of ``x`` (the cell's input).
            recall: stored only; ``forward`` does not read it.
            group_norm: stored; only read by ``_make_layer``.
            bias: whether the projection and head convolutions have a bias.
            _dropout_i: stored only; no input-dropout module is built.
            _dropout_h: rate for ``self._state_drop`` (applied to every cell output).
            **kwargs: ignored.
        """
        super().__init__()

        self.name = "NetConvLSTM9"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Layer construction order is unchanged so seeded initialisation and
        # state_dict layout match earlier versions.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.lstm = ConvLSTMCell(in_channels, width, (3, 3), True)

        self._dropout_i = _dropout_i
        self._dropout_h = _dropout_h

        self._state_drop = SampleDropConv(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``; the rest use stride 1. Updates
        ``self.width`` to ``planes * block.expansion`` as it goes. Not called by
        this class's constructor; kept for parity with the DTNet family.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do * 6`` cell steps and decode every sixth one.

        Args:
            x: input batch, indexed as ``(batch, channels, height, width)``.
            iters_to_do: number of decoded outputs to produce.
            interim_thought: must be ``None``; resuming is not supported.
            return_all_outputs: in training mode, also return the stacked outputs.
            **kwargs: ignored.

        Returns:
            Eval mode: ``all_outputs`` of shape ``(batch, iters_to_do, 2, height, width)``.
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)`` when
            ``return_all_outputs`` is set; ``out`` is the last decoded output.

        Raises:
            NotImplementedError: if ``interim_thought`` is given.
        """
        # Explicit raise instead of ``assert False``: asserts vanish under ``python -O``.
        if interim_thought is not None:
            raise NotImplementedError("NetConvLSTM9 cannot resume from interim_thought")

        interim_thought = self.projection(x)

        # Allocate straight on x's device (default dtype) instead of CPU + copy.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)), device=x.device)

        mul = 6  # cell steps per decoded output (see class docstring)

        self.lstm.sample_mask(interim_thought.device)

        state = None  # the first cell call starts from an empty state
        for i in range(iters_to_do * mul):
            if i == 0:
                # set_weights is only called once per unroll, with the initial thought.
                self._state_drop.set_weights(interim_thought)

            interim_thought, c = self.lstm(x, state)
            interim_thought = self._state_drop(interim_thought)
            state = (interim_thought, c)

            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

def dt_convlstm_onlyx_1layer_2d(width, **kwargs):
    """Build a single-cell ``NetConvLSTM9`` with no dropout (both rates at their 0 default).

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM9(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True)

def dt_convlstm_onlyx_1layer_drop01_2d(width, **kwargs):
    """Build a single-cell ``NetConvLSTM9`` with ``_dropout_i=0.1`` and ``_dropout_h=0.1``.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM9(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_i=0.1, _dropout_h=0.1)

def dt_convlstm_onlyx_1layer_drop02_2d(width, **kwargs):
    """Build a single-cell ``NetConvLSTM9`` with ``_dropout_i=0.2`` and ``_dropout_h=0.2``.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM9(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_i=0.2, _dropout_h=0.2)

def dt_convlstm_onlyx_1layer_drop03_2d(width, **kwargs):
    """Build a single-cell ``NetConvLSTM9`` with hidden-state dropout ``_dropout_h=0.3`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM9(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.3)

def dt_convlstm_onlyx_1layer_drop04_2d(width, **kwargs):
    """Build a single-cell ``NetConvLSTM9`` with hidden-state dropout ``_dropout_h=0.4`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM9(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.4)

def dt_convlstm_onlyx_1layer_drop05_2d(width, **kwargs):
    """Build a single-cell ``NetConvLSTM9`` with hidden-state dropout ``_dropout_h=0.5`` only.

    ``kwargs`` must provide ``"in_channels"``; any other key is ignored.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM9(BasicBlock, [2], width=width,
                        in_channels=in_channels, recall=True,
                        _dropout_h=0.5)

## fix dropout
class NetConvLSTM_LN(nn.Module):
    """Single-cell ConvLSTM DeepThinking model for 1-D or 2-D inputs.

    ``x`` enters the model only through ``self.lstm.forward_input(x)``. That
    call is made once per unroll and its result is reused at every step. The
    cell is then stepped 5 times per decoded output. After every step the
    hidden state passes through ``self._state_drop`` (a ``SampleDropND`` with
    rate ``_dropout_gal2``). Normalisation and in-cell dropout options
    (``ln_preact``, ``use_instance_norm``, ``norm_affine``, ``_dropout``,
    ``dropout_method``) are forwarded to the cell constructor.

    ``block``, ``num_blocks``, ``recall`` and ``group_norm`` are accepted for
    signature compatibility; no residual recurrent block is built.
    """

    # Values accepted for the ``dropout_method`` constructor argument.
    _DROPOUT_METHODS = ('pytorch', 'gal', 'moon', 'semeniuta', 'input')

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout=0, dropout_method='pytorch', ln_preact=True, use_instance_norm=False,
                 _dropout_gal2=0, norm_affine=True,
                 lstm_class=ConvLSTMCellV3,
                 conv_dim: int = 2,
                 output_size=2, flatten=False, **kwargs):
        """
        Args:
            block, num_blocks: unused by this constructor.
            width: channel count of the hidden state.
            in_channels: channel count of ``x``; forced to 1 when ``conv_dim == 1``.
            recall: stored only.
            group_norm: stored; only read by ``_make_layer``.
            bias: whether the projection and head convolutions have a bias.
            _dropout: dropout rate passed to the cell as ``dropout``.
            dropout_method: one of ``_DROPOUT_METHODS``, passed to the cell.
            ln_preact, use_instance_norm: passed to the cell unchanged.
            _dropout_gal2: rate of the ``SampleDropND`` applied to the hidden state.
            norm_affine: passed to the cell as ``learnable``.
            lstm_class: cell class; only ``ConvLSTMCellV3`` is supported.
            conv_dim: 1 (``nn.Conv1d``) or 2 (``nn.Conv2d``).
            output_size: channel count of the head's output.
            flatten: flatten each decoded output to ``(batch, -1)``.
            **kwargs: ignored.

        Raises:
            NotImplementedError: for ``conv_dim`` other than 1 or 2, or an
                unsupported ``lstm_class``.
            ValueError: for an unknown ``dropout_method``.
        """
        super().__init__()

        # Validate the configuration up front, before any layer is built, with
        # explicit raises (asserts vanish under ``python -O``).
        if conv_dim not in (1, 2):
            raise NotImplementedError(f"conv_dim={conv_dim!r} is not implemented")
        if dropout_method not in self._DROPOUT_METHODS:
            raise ValueError(
                f"dropout_method must be one of {self._DROPOUT_METHODS}, got {dropout_method!r}")
        if lstm_class is not ConvLSTMCellV3:
            raise NotImplementedError("only ConvLSTMCellV3 is supported as lstm_class")

        self.name = "NetConvLSTM_LN"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        if conv_dim == 2:
            conv_class = nn.Conv2d
        else:
            conv_class = nn.Conv1d
            # 1-D inputs are always treated as single-channel: the in_channels
            # argument is overridden for both the projection and the cell. The
            # reason is not documented upstream.
            in_channels = 1

        def make_head_convs(channels):
            """Build kernel-3, stride-1, padding-1 convs between consecutive ``channels``, in order."""
            return [conv_class(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=bias)
                    for c_in, c_out in zip(channels[:-1], channels[1:])]

        proj_conv = conv_class(in_channels, width, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        wide_head = (width, width, int(width / 2), output_size)
        if conv_dim == 2:
            head_convs = make_head_convs((width, 32, 8, output_size))
            if output_size != 2:
                # The 32/8 head above is discarded. It is still built because
                # constructing it advances the RNG; dropping it would change the
                # seeded initialisation of every layer created after this point.
                head_convs = make_head_convs(wide_head)
        else:
            head_convs = make_head_convs(wide_head)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_convs[0], nn.ReLU(),
                                  head_convs[1], nn.ReLU(),
                                  head_convs[2])

        self.lstm = lstm_class(in_channels, width, (3, 3), True,
                               dropout=_dropout, ln_preact=ln_preact, dropout_method=dropout_method,
                               use_instance_norm=use_instance_norm, learnable=norm_affine,
                               conv_dim=conv_dim)

        # Dropout applied to the hidden state after every cell step.
        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDropND(dropout=self._dropout_h)

        self.output_size = output_size
        self.flatten = flatten

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.

        Only the first block uses ``stride``; the rest use stride 1. Updates
        ``self.width`` to ``planes * block.expansion`` as it goes. Not called by
        this class's constructor; kept for parity with the DTNet family.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def _alloc_outputs(self, x, iters_to_do):
        """Return a zero buffer for every iteration's decoded output, on ``x``'s device.

        The shape is ``(batch, iters_to_do, output_size, *x.shape[2:])``, or
        ``(batch, iters_to_do, output_size * prod(x.shape[2:]))`` when
        ``self.flatten`` is set (for 3-D as well as 4-D ``x``).

        Raises:
            NotImplementedError: if ``x`` is not 3-D or 4-D.
        """
        if x.dim() not in (3, 4):
            raise NotImplementedError(f"inputs with {x.dim()} dims are not implemented")
        spatial = tuple(x.shape[2:])
        if self.flatten:
            shape = (x.size(0), iters_to_do, self.output_size * torch.Size(spatial).numel())
        else:
            shape = (x.size(0), iters_to_do, self.output_size) + spatial
        return torch.zeros(shape, device=x.device)

    def _unroll(self, x, iters_to_do, interim_thought, all_outputs, all_hidden=None):
        """Run the recurrence and write decoded outputs (and hidden states) in place.

        Returns:
            ``(out, thought)``: the last decoded output (``None`` when
            ``iters_to_do`` is 0) and the final state-dropped hidden state.
        """
        if interim_thought is None:
            interim_thought = self.projection(x)
        # NOTE(review): a caller-supplied interim_thought does not seed the
        # recurrence. The first cell step always starts from state None; the
        # thought only supplies the device for sample_mask and the tensor given
        # to _state_drop.set_weights.

        # 5 cell steps per decoded output. An older comment called this
        # "equivalent to 3 LSTMs", which does not match 5. NOTE(review): 5 may
        # mirror the 1 + 2*2 = 5 convolutions of one DeepThinking recurrent
        # block — confirm.
        mul = 5
        self.lstm.sample_mask(interim_thought.device)

        # Computed once and reused at every step.
        lstm_inp1 = self.lstm.forward_input(x)

        out = None
        c = None
        for i in range(iters_to_do * mul):
            if i == 0:
                self._state_drop.set_weights(interim_thought)
                state = None
            else:
                state = (interim_thought, c)

            interim_thought, c = self.lstm(lstm_inp1, state)
            interim_thought = self._state_drop(interim_thought)

            if i % mul == mul - 1:
                k = i // mul
                out = self.head(interim_thought)
                if self.flatten:
                    out = out.flatten(1)
                all_outputs[:, k] = out
                if all_hidden is not None:
                    all_hidden[:, k] = interim_thought

        return out, interim_thought

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Unroll the cell ``iters_to_do * 5`` times, decoding every fifth step.

        Args:
            x: 3-D (1-D conv) or 4-D (2-D conv) input batch.
            iters_to_do: number of decoded outputs to produce.
            interim_thought: optional starting thought (see NOTE in ``_unroll``).
            return_all_outputs: in training mode, also return the stacked outputs.
            **kwargs: ignored.

        Returns:
            Eval mode: ``all_outputs`` (shape as in ``_alloc_outputs``).
            Training mode: ``(out, thought)``, or ``(all_outputs, out, thought)``
            when ``return_all_outputs`` is set.

        Raises:
            NotImplementedError: if ``x`` is not 3-D or 4-D.
            ValueError: in training mode when ``iters_to_do`` is 0, since there
                is no final output to return.
        """
        all_outputs = self._alloc_outputs(x, iters_to_do)
        out, interim_thought = self._unroll(x, iters_to_do, interim_thought, all_outputs)

        if not self.training:
            return all_outputs
        if out is None:
            raise ValueError("iters_to_do must be at least 1 to return a final output in training mode")
        if return_all_outputs:
            return all_outputs, out, interim_thought
        return out, interim_thought

    def forward_return_hidden(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Like ``forward``, but also collect the hidden state at every decoded step.

        Works the same in train and eval mode. ``return_all_outputs`` and
        ``kwargs`` are accepted for signature parity and ignored.

        Returns:
            ``(all_outputs, all_hidden)``, where ``all_hidden`` has shape
            ``(batch, iters_to_do, self.width, height, width)``.

        Raises:
            NotImplementedError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise NotImplementedError("forward_return_hidden only supports 4-D inputs")
        all_outputs = self._alloc_outputs(x, iters_to_do)
        all_hidden = torch.zeros((x.size(0), iters_to_do, self.width, x.size(2), x.size(3)), device=x.device)
        self._unroll(x, iters_to_do, interim_thought, all_outputs, all_hidden)
        return all_outputs, all_hidden


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_py03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.4, 'pytorch' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.3, dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_1l_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False with both dropout rates explicitly 0."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0, norm_affine=False,
        _dropout=0, dropout_method='pytorch',
    )


def convlstm_ln_1l_sgal04_py03_noreduce_4out_2d(width, **kwargs):
    """NetConvLSTM_LN: _dropout_gal2=0.4, 'pytorch' dropout 0.3, output_size=4."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.3, dropout_method='pytorch',
        output_size=4,
    )


def convlstm_ln_1l_sgal04_py03_flatten_4out_2d(width, **kwargs):
    """NetConvLSTM_LN: _dropout_gal2=0.4, 'pytorch' dropout 0.3, output_size=4, flattened output."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.3, dropout_method='pytorch',
        output_size=4, flatten=True,
    )


def convlstm_ln_1l_noreduce_5out_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, output_size=5; no dropout arguments are passed."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        norm_affine=False,
        output_size=5,
    )


def convlstm_ln_1l_sgal04_py03_noreduce_5out_2d(width, **kwargs):
    """NetConvLSTM_LN: _dropout_gal2=0.4, 'pytorch' dropout 0.3, output_size=5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.3, dropout_method='pytorch',
        output_size=5,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_py03_1d(width, **kwargs):
    """1-D variant (conv_dim=1) of the sgal04_py03 NetConvLSTM_LN configuration."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.3, dropout_method='pytorch',
        conv_dim=1,
    )


def dt_convlstm_ln_onlyx_1layer_2d(width, **kwargs):
    """NetConvLSTM_LN with only width / in_channels / recall=True specified."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_2d(width, **kwargs):
    """NetConvLSTM_LN with norm_affine=False."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        norm_affine=False,
    )


def dt_convlstm_instn_onlyx_1layer_2d(width, **kwargs):
    """NetConvLSTM_LN with use_instance_norm=True."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        use_instance_norm=True,
    )


def dt_convlstm_instn_onlyx_1layer_pydrop02_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN with 'pytorch' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.2, dropout_method='pytorch', use_instance_norm=True,
    )


def dt_convlstm_instn_onlyx_1layer_pydrop05_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN with 'pytorch' dropout 0.5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.5, dropout_method='pytorch', use_instance_norm=True,
    )


def dt_convlstm_instn_onlyx_1layer_gal02_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN with 'gal' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.2, dropout_method='gal', use_instance_norm=True,
    )


def dt_convlstm_ln_onlyx_1layer_pydrop02_2d(width, **kwargs):
    """NetConvLSTM_LN with 'pytorch' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.2, dropout_method='pytorch',
    )


def dt_convlstm_ln_onlyx_1layer_gal02_2d(width, **kwargs):
    """NetConvLSTM_LN with 'gal' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.2, dropout_method='gal',
    )


def dt_convlstm_ln_onlyx_1layer_gal03_2d(width, **kwargs):
    """NetConvLSTM_LN with 'gal' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.3, dropout_method='gal',
    )


def dt_convlstm_ln_onlyx_1layer_gal05_2d(width, **kwargs):
    """NetConvLSTM_LN with 'gal' dropout 0.5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.5, dropout_method='gal',
    )


def dt_convlstm_ln_onlyx_1layer_moon02_2d(width, **kwargs):
    """NetConvLSTM_LN with 'moon' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.2, dropout_method='moon',
    )


def dt_convlstm_ln_onlyx_1layer_moon05_2d(width, **kwargs):
    """NetConvLSTM_LN with 'moon' dropout 0.5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.5, dropout_method='moon',
    )


def dt_convlstm_ln_onlyx_1layer_sem02_2d(width, **kwargs):
    """NetConvLSTM_LN with 'semeniuta' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout=0.2, dropout_method='semeniuta',
    )

def dt_convlstm_ln_onlyx_1l_samegal01_2d(width, **kwargs):
    """NetConvLSTM_LN with _dropout_gal2=0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.1,
    )


def dt_convlstm_ln_onlyx_1l_samegal02_2d(width, **kwargs):
    """NetConvLSTM_LN with _dropout_gal2=0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2,
    )


def dt_convlstm_ln_onlyx_1l_samegal03_2d(width, **kwargs):
    """NetConvLSTM_LN with _dropout_gal2=0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3,
    )


def dt_convlstm_ln_onlyx_1l_samegal05_2d(width, **kwargs):
    """NetConvLSTM_LN with _dropout_gal2=0.5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.5,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.1, norm_affine=False,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, norm_affine=False,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, norm_affine=False,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.4."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal05_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.5, norm_affine=False,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal07_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.7."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.7, norm_affine=False,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal08_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.8."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.8, norm_affine=False,
    )



# ---- NetConvLSTM_LN, norm_affine=False: _dropout_gal2 combined with moon / semeniuta / pytorch dropout ----

def dt_convlstm_ln_nopre_onlyx_1l_sgal01_py01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.1, 'pytorch' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.1, norm_affine=False,
        _dropout=0.1, dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal01_moon01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.1, 'moon' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.1, norm_affine=False,
        _dropout=0.1, dropout_method='moon',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal02_moon02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.2, 'moon' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, norm_affine=False,
        _dropout=0.2, dropout_method='moon',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal03_moon03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.3, 'moon' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, norm_affine=False,
        _dropout=0.3, dropout_method='moon',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal01_sem01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.1, 'semeniuta' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.1, norm_affine=False,
        _dropout=0.1, dropout_method='semeniuta',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal02_sem02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.2, 'semeniuta' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, norm_affine=False,
        _dropout=0.2, dropout_method='semeniuta',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal03_sem03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.3, 'semeniuta' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, norm_affine=False,
        _dropout=0.3, dropout_method='semeniuta',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_sem04_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.4, 'semeniuta' dropout 0.4."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.4, dropout_method='semeniuta',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_sem02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.4, 'semeniuta' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.2, dropout_method='semeniuta',
    )


# dt_convlstm_ln_nopre_onlyx_1l_sgal01_py01_2d is defined once, at the top of
# this section; an identical redefinition that used to sit here was removed.


def dt_convlstm_ln_nopre_onlyx_1l_sgal02_py02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.2, 'pytorch' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, norm_affine=False,
        _dropout=0.2, dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal03_py03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.3, 'pytorch' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, norm_affine=False,
        _dropout=0.3, dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal03_py05_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.3, 'pytorch' dropout 0.5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, norm_affine=False,
        _dropout=0.5, dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_py01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.4, 'pytorch' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.1, dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_py02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.4, 'pytorch' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.2, dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_py04_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.4, 'pytorch' dropout 0.4."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.4, dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal05_py05_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, _dropout_gal2=0.5, 'pytorch' dropout 0.5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.5, norm_affine=False,
        _dropout=0.5, dropout_method='pytorch',
    )




# NOTE(review): several factories below do not pass ``recall`` and therefore
# rely on NetConvLSTM_LN's own default for it (not visible here).


def dt_convlstm_ln_nopre_onlyx_1layer_sem01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'semeniuta' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        norm_affine=False,
        dropout_method='semeniuta', _dropout=0.1,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_sem02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'semeniuta' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        norm_affine=False,
        dropout_method='semeniuta', _dropout=0.2,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_sem03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'semeniuta' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        norm_affine=False,
        dropout_method='semeniuta', _dropout=0.3,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_moon01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'moon' dropout 0.1 (recall not passed)."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
        norm_affine=False,
        dropout_method='moon', _dropout=0.1,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_moon02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'moon' dropout 0.2 (recall not passed)."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
        norm_affine=False,
        dropout_method='moon', _dropout=0.2,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_moon03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'moon' dropout 0.3 (recall not passed)."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
        norm_affine=False,
        dropout_method='moon', _dropout=0.3,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_gal01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'gal' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        norm_affine=False,
        dropout_method='gal', _dropout=0.1,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_gal02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'gal' dropout 0.2 (recall not passed)."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
        norm_affine=False,
        dropout_method='gal', _dropout=0.2,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_gal03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'gal' dropout 0.3 (recall not passed)."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
        norm_affine=False,
        dropout_method='gal', _dropout=0.3,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_pydrop01_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'pytorch' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        norm_affine=False,
        dropout_method='pytorch', _dropout=0.1,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_pydrop02_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'pytorch' dropout 0.2 (recall not passed)."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
        norm_affine=False,
        dropout_method='pytorch', _dropout=0.2,
    )


def dt_convlstm_ln_nopre_onlyx_1layer_pydrop03_2d(width, **kwargs):
    """NetConvLSTM_LN: norm_affine=False, 'pytorch' dropout 0.3 (recall not passed)."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
        norm_affine=False,
        dropout_method='pytorch', _dropout=0.3,
    )



# ---- Instance-norm (use_instance_norm=True) ConvLSTM variants ----
def dt_convlstm_instn_onlyx_1l_sgal02_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, _dropout_gal2=0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, use_instance_norm=True,
    )


def dt_convlstm_instn_onlyx_1l_sgal03_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, _dropout_gal2=0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, use_instance_norm=True,
    )


def dt_convlstm_in_noaf_nopre_onlyx_1l_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, ln_preact=False."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        use_instance_norm=True, norm_affine=False, ln_preact=False,
    )


def dt_convlstm_in_noaff_onlyx_1l_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        use_instance_norm=True, norm_affine=False,
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal01_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.1, use_instance_norm=True, norm_affine=False,
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal02_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, use_instance_norm=True, norm_affine=False,
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal03_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, use_instance_norm=True, norm_affine=False,
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal03_ipdrop_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.3, 'input' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, use_instance_norm=True, norm_affine=False,
        _dropout=0.3, dropout_method='input',
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal04_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.4."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, use_instance_norm=True, norm_affine=False,
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal05_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.5."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.5, use_instance_norm=True, norm_affine=False,
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal07_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.7."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.7, use_instance_norm=True, norm_affine=False,
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal01_py01_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.1, 'pytorch' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.1, use_instance_norm=True, norm_affine=False,
        _dropout=0.1, dropout_method='pytorch',
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal01_sem01_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.1, 'semeniuta' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.1, use_instance_norm=True, norm_affine=False,
        _dropout=0.1, dropout_method='semeniuta',
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal02_py02_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.2, 'pytorch' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, use_instance_norm=True, norm_affine=False,
        _dropout=0.2, dropout_method='pytorch',
    )


def dt_convlstm_in_noaff_onlyx_1l_sgal04_py04_2d(width, **kwargs):
    """Instance-norm NetConvLSTM_LN, norm_affine=False, _dropout_gal2=0.4, 'pytorch' dropout 0.4."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, use_instance_norm=True, norm_affine=False,
        _dropout=0.4, dropout_method='pytorch',
    )


### sem + other drop
def dt_convlstm_ln_nopre_onlyx_1l_sgal03_sem_py01_2d(width, **kwargs):
    """ConvLSTMCellV2SemDrop cell, norm_affine=False, _dropout_gal2=0.3, 'pytorch' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, norm_affine=False,
        _dropout=0.1, dropout_method='pytorch',
        lstm_class=ConvLSTMCellV2SemDrop,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal03_sem_py02_2d(width, **kwargs):
    """ConvLSTMCellV2SemDrop cell, norm_affine=False, _dropout_gal2=0.3, 'pytorch' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, norm_affine=False,
        _dropout=0.2, dropout_method='pytorch',
        lstm_class=ConvLSTMCellV2SemDrop,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal03_sem_py03_2d(width, **kwargs):
    """ConvLSTMCellV2SemDrop cell, norm_affine=False, _dropout_gal2=0.3, 'pytorch' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.3, norm_affine=False,
        _dropout=0.3, dropout_method='pytorch',
        lstm_class=ConvLSTMCellV2SemDrop,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_sem_py01_2d(width, **kwargs):
    """ConvLSTMCellV2SemDrop cell, norm_affine=False, _dropout_gal2=0.4, 'pytorch' dropout 0.1."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.1, dropout_method='pytorch',
        lstm_class=ConvLSTMCellV2SemDrop,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_sem_py02_2d(width, **kwargs):
    """ConvLSTMCellV2SemDrop cell, norm_affine=False, _dropout_gal2=0.4, 'pytorch' dropout 0.2."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.2, dropout_method='pytorch',
        lstm_class=ConvLSTMCellV2SemDrop,
    )


def dt_convlstm_ln_nopre_onlyx_1l_sgal04_sem_py03_2d(width, **kwargs):
    """ConvLSTMCellV2SemDrop cell, norm_affine=False, _dropout_gal2=0.4, 'pytorch' dropout 0.3."""
    return NetConvLSTM_LN(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.4, norm_affine=False,
        _dropout=0.3, dropout_method='pytorch',
        lstm_class=ConvLSTMCellV2SemDrop,
    )

## fix dropout
class NetConvLSTM_LN_Reduce_Input(nn.Module):
    """DeepThinking Network 2D model class driven by a single ConvLSTM cell.

    The cell (``ConvLSTMCellV2_ReduceInput``) receives the raw input ``x`` at
    every inner step and returns ``(h, c, c_i)``. In training mode a
    BCE-with-logits penalty of ``c_i`` against an all-zeros target is averaged
    over all inner steps and returned as an auxiliary loss.

    Args:
        block, num_blocks: kept for signature compatibility with the other
            DeepThinking models; no residual layers are built from them here.
        width: hidden-state channels of the cell and input channels of the head.
        in_channels: channels of the input ``x``.
        recall: stored on the instance only.
        group_norm: stored; read by ``_make_layer``.
        bias: bias flag of the projection and head convolutions.
        _dropout, dropout_method, ln_preact, use_instance_norm: forwarded to the cell.
        norm_affine: forwarded to the cell as ``learnable``.
        _dropout_gal2: rate of the ``SampleDrop2D`` applied to the hidden state
            after every inner step.
        **kwargs: accepted and ignored.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout=0, dropout_method='pytorch', ln_preact=True, use_instance_norm=False,
                 _dropout_gal2=0, norm_affine=True, **kwargs):
        super().__init__()

        self.name = "NetConvLSTM_LN_Reduce_Input"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Module construction order is kept as-is so parameter initialisation
        # consumes the RNG in the same order as before.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        # Output head: width -> 32 -> 8 -> 2 channels; 3x3 / stride 1 / pad 1
        # keeps the spatial size unchanged.
        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        assert dropout_method in ['pytorch', 'gal', 'moon', 'semeniuta', 'input']

        self.lstm = ConvLSTMCellV2_ReduceInput(in_channels, width, (3, 3), True,
                                               dropout=_dropout, ln_preact=ln_preact,
                                               dropout_method=dropout_method,
                                               use_instance_norm=use_instance_norm,
                                               learnable=norm_affine)

        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDrop2D(dropout=self._dropout_h)

        # Must match head_conv3's output channels. FIXME: not configurable.
        self.output_size = 2

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` ``block`` instances (the first one with ``stride``).

        Updates ``self.width`` to the last block's output width. Not called by
        this class; kept for interface parity with the other models.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` outer iterations of ``steps_per_iter`` cell steps each.

        Args:
            x: input batch; ``x.size(0)``, ``x.size(2)`` and ``x.size(3)`` size
                the output buffer (presumably (B, C, H, W) — TODO confirm).
            iters_to_do: number of outer iterations; one head output is
                recorded at the end of each.
            interim_thought: must be None; resuming from a state is not implemented.
            return_all_outputs: in training mode, also return the per-iteration outputs.
            **kwargs: accepted and ignored.

        Returns:
            Eval mode: ``all_outputs`` of shape
            (x.size(0), iters_to_do, output_size, x.size(2), x.size(3)).
            Training mode: ``(out, h, aux)``, or ``(all_outputs, out, h, aux)``
            when ``return_all_outputs`` — ``out`` is the last head output,
            ``h`` the last hidden state and ``aux`` the mean BCE penalty.

        Raises:
            AssertionError: if ``interim_thought`` is given.
            ValueError: in training mode when no iteration runs (``iters_to_do < 1``).
        """
        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought
        else:
            assert False, "not implemented"

        # Inner cell steps per recorded output ("to be equivalent to 3 LSTMs").
        steps_per_iter = 6
        total_steps = iters_to_do * steps_per_iter

        # Allocate directly on the input's device instead of CPU + copy.
        all_outputs = torch.zeros(
            (x.size(0), iters_to_do, self.output_size, x.size(2), x.size(3)),
            device=x.device,
        )

        # The projection output is never fed to the cell (the first step uses
        # state=None); it is only handed to sample_mask (for its device) and to
        # _state_drop.set_weights below.
        self.lstm.sample_mask(interim_thought.device)

        bce_with_logits = nn.functional.binary_cross_entropy_with_logits
        total_ci = 0
        out = None

        for i in range(total_steps):
            if i == 0:
                self._state_drop.set_weights(interim_thought)
                state = None
            else:
                state = (interim_thought, c)

            interim_thought, c, c_i = self.lstm(x, state)
            interim_thought = self._state_drop(interim_thought)

            if self.training:
                # Auxiliary penalty of c_i against an all-zeros target.
                total_ci += bce_with_logits(c_i, torch.zeros_like(c_i))

            if i % steps_per_iter == steps_per_iter - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // steps_per_iter] = out

        # Guard: with no steps the original divided by zero.
        if total_steps > 0:
            total_ci /= total_steps

        if self.training:
            if out is None:
                raise ValueError("iters_to_do must be >= 1 in training mode")
            if return_all_outputs:
                return all_outputs, out, interim_thought, total_ci
            return out, interim_thought, total_ci

        return all_outputs


def dt_convlstm_iloss_in_noaff_onlyx_1l_sgal02_2d(width, **kwargs):
    """Reduce-input ConvLSTM: instance norm, norm_affine=False, _dropout_gal2=0.2."""
    return NetConvLSTM_LN_Reduce_Input(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, use_instance_norm=True, norm_affine=False,
    )


def dt_convlstm_iloss_ln_noaff_onlyx_1l_sgal02_2d(width, **kwargs):
    """Reduce-input ConvLSTM: norm_affine=False, _dropout_gal2=0.2."""
    return NetConvLSTM_LN_Reduce_Input(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, norm_affine=False,
    )


def dt_convlstm_iloss_in_noaff_nopre_onlyx_1l_sgal02_2d(width, **kwargs):
    """Reduce-input ConvLSTM: instance norm, norm_affine=False, ln_preact=False, _dropout_gal2=0.2."""
    return NetConvLSTM_LN_Reduce_Input(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, use_instance_norm=True, norm_affine=False, ln_preact=False,
    )


def dt_convlstm_iloss_ln_noaff_nopre_onlyx_1l_sgal02_2d(width, **kwargs):
    """Reduce-input ConvLSTM: norm_affine=False, ln_preact=False, _dropout_gal2=0.2."""
    return NetConvLSTM_LN_Reduce_Input(
        BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
        _dropout_gal2=0.2, norm_affine=False, ln_preact=False,
    )


## fix dropout
class NetConvGRU_Reduce_Input(nn.Module):
    """DeepThinking Network 2D model class driven by a single ConvGRU cell.

    The cell (``ConvGRUCell``) receives the raw input ``x`` at every inner step
    and returns ``(h, c_i)``. In training mode a BCE-with-logits penalty of
    ``c_i`` against an all-zeros target is averaged over all inner steps and
    (unless ``no_aux_loss``) returned as an auxiliary loss.

    Args:
        block, num_blocks: kept for signature compatibility; unused here.
        width: hidden-state channels of the cell and input channels of the head.
        in_channels: channels of the input ``x``.
        recall: stored on the instance only.
        group_norm: stored; read by ``_make_layer``.
        bias: bias flag of the projection and head convolutions.
        _dropout, ln_preact, use_instance_norm, norm_affine: accepted but not
            used by this class.
        dropout_method: only validated against the allowed names.
        _dropout_gal2: rate of the ``SampleDrop2D`` applied to the hidden state
            after every inner step.
        no_aux_loss: in training mode (without ``return_all_outputs``) return
            ``(out, h)`` instead of ``(out, h, aux)``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout=0, dropout_method='pytorch', ln_preact=True, use_instance_norm=False,
                 _dropout_gal2=0, norm_affine=True, no_aux_loss=False, **kwargs):
        super().__init__()

        # NOTE(review): this string was copied from the LSTM variant; it is left
        # unchanged in case other code keys on it — confirm before renaming.
        self.name = "NetConvLSTM_LN_Reduce_Input"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Construction order is kept as-is so parameter initialisation
        # consumes the RNG in the same order as before.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        self.no_aux_loss = no_aux_loss

        # Output head: width -> 32 -> 8 -> 2 channels; 3x3 / stride 1 / pad 1
        # keeps the spatial size unchanged.
        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        # The projection is not used by forward(); it is kept so the module
        # structure (and checkpoints) stay unchanged.
        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        assert dropout_method in ['pytorch', 'gal', 'moon', 'semeniuta', 'input']

        self.lstm = ConvGRUCell(in_channels, width, (3, 3), True,)

        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDrop2D(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` ``block`` instances (the first one with ``stride``).

        Updates ``self.width`` to the last block's output width. Not called by
        this class; kept for interface parity with the other models.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` outer iterations of ``steps_per_iter`` GRU steps each.

        Args:
            x: input batch; ``x.size(0)``, ``x.size(2)`` and ``x.size(3)`` size
                the output buffer (presumably (B, C, H, W) — TODO confirm).
            iters_to_do: number of outer iterations; one head output is
                recorded at the end of each.
            interim_thought: must be None; resuming from a state is not implemented.
            return_all_outputs: in training mode, also return the per-iteration outputs.
            **kwargs: accepted and ignored.

        Returns:
            Eval mode: ``all_outputs`` of shape
            (x.size(0), iters_to_do, 2, x.size(2), x.size(3)).
            Training mode: ``(all_outputs, out, h, aux)`` if ``return_all_outputs``;
            otherwise ``(out, h)`` if ``no_aux_loss``, else ``(out, h, aux)``.

        Raises:
            AssertionError: if ``interim_thought`` is given.
            ValueError: in training mode when no iteration runs (``iters_to_do < 1``).
        """
        # The projection output was never consumed (the first GRU step starts
        # from state=None), so it is no longer computed here.
        assert interim_thought is None, "not implemented"

        # Inner GRU steps per recorded output ("to be equivalent to 3 LSTMs").
        steps_per_iter = 6
        total_steps = iters_to_do * steps_per_iter

        # 2 output channels == head_conv3's out_channels.
        all_outputs = torch.zeros(
            (x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
            device=x.device,
        )

        bce_with_logits = nn.functional.binary_cross_entropy_with_logits
        # The aux loss is only worth computing when it will be returned.
        compute_aux = self.training and (return_all_outputs or not self.no_aux_loss)
        total_ci = 0
        out = None
        state = None

        for i in range(total_steps):
            # NOTE(review): unlike the LSTM variant, _state_drop.set_weights()
            # is never called here — confirm SampleDrop2D does not need it.
            interim_thought, c_i = self.lstm(x, state)
            interim_thought = self._state_drop(interim_thought)
            state = interim_thought

            if compute_aux:
                # Auxiliary penalty of c_i against an all-zeros target.
                total_ci += bce_with_logits(c_i, torch.zeros_like(c_i))

            if i % steps_per_iter == steps_per_iter - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // steps_per_iter] = out

        # Guard: with no steps the original divided by zero.
        if total_steps > 0:
            total_ci /= total_steps

        if self.training:
            if out is None:
                raise ValueError("iters_to_do must be >= 1 in training mode")
            if return_all_outputs:
                return all_outputs, out, interim_thought, total_ci
            if self.no_aux_loss:
                return out, interim_thought
            return out, interim_thought, total_ci

        return all_outputs


def dt_convgru_iloss(width, **kwargs):
    """Build a ``NetConvGRU_Reduce_Input`` with recall, leaving ``no_aux_loss`` unset.

    ``kwargs`` must provide ``"in_channels"`` (a missing key raises KeyError).
    """
    in_channels = kwargs["in_channels"]
    return NetConvGRU_Reduce_Input(BasicBlock, [2], width=width, in_channels=in_channels, recall=True)


def dt_convgru(width, **kwargs):
    """Build a ``NetConvGRU_Reduce_Input`` with recall and ``no_aux_loss=True``.

    ``kwargs`` must provide ``"in_channels"`` (a missing key raises KeyError).
    """
    in_channels = kwargs["in_channels"]
    return NetConvGRU_Reduce_Input(BasicBlock, [2], width=width, in_channels=in_channels,
                                   recall=True, no_aux_loss=True)


## fix dropout
class NetConvLSTM_LN_3L(nn.Module):
    """DeepThinking 2D network driven by three stacked ConvLSTM cells.

    The recurrent core consists of three ``ConvLSTMCellV3`` cells run in
    sequence (``lstm`` -> ``lstm2`` -> ``lstm3``). The hidden state is chained
    from one cell into the next, while each cell keeps its own cell state
    across passes. The cells receive the input only through
    ``forward_input(x)``, computed once per forward. The three-cell stack is
    run ``mul`` (= 2) times per outer iteration, and ``head`` decodes the
    hidden state once per outer iteration.

    Args:
        block, num_blocks: kept for signature compatibility with the other
            DeepThinking models; no layer is built from them here.
        width: hidden channels of the cells and input channels of ``head``.
        in_channels: channels of the input ``x``.
        recall: stored on the instance; not read by ``forward``.
        group_norm: stored on the instance (only read by ``_make_layer``).
        bias: bias flag of the projection and head convolutions.
        _dropout, dropout_method: dropout configuration forwarded to every
            ``ConvLSTMCellV3``. ``dropout_method`` must be one of 'pytorch',
            'gal', 'moon', 'semeniuta', 'input'.
        ln_preact, use_instance_norm: normalisation options forwarded to
            every ``ConvLSTMCellV3``.
        _dropout_gal2: rate of the shared ``SampleDrop2D`` applied to the
            hidden state after every cell.
        norm_affine: forwarded to every cell as ``learnable``.

    Raises:
        ValueError: if ``dropout_method`` is not supported.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,
                  _dropout=0,dropout_method='pytorch',ln_preact=True,use_instance_norm=False,
                  _dropout_gal2=0,norm_affine=True,**kwargs):
        super().__init__()

        # Explicit check rather than `assert`, so it is not stripped by `python -O`.
        if dropout_method not in ('pytorch', 'gal', 'moon', 'semeniuta', 'input'):
            raise ValueError(f"unsupported dropout_method: {dropout_method!r}")

        self.name = "NetConvLSTM_LN_3L"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Keep this construction order: seeded parameter initialisation
        # depends on the order in which modules draw from the RNG.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        def make_cell():
            # All three cells share the same configuration.
            return ConvLSTMCellV3(in_channels, width, (3, 3), True,
                                  dropout=_dropout, ln_preact=ln_preact,
                                  dropout_method=dropout_method,
                                  use_instance_norm=use_instance_norm,
                                  learnable=norm_affine)

        self.lstm = make_cell()
        self.lstm2 = make_cell()
        self.lstm3 = make_cell()

        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDrop2D(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks (not used by this model; kept for parity with DTNet)."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` outer iterations and decode one output per iteration.

        Args:
            x: input batch; its sizes 0, 2 and 3 define the output shape
                (assumed (batch, in_channels, H, W) -- TODO confirm).
            iters_to_do: number of outer iterations (each runs the three-cell
                stack ``mul`` times). Must be >= 1 in training mode.
            interim_thought: optional tensor. It is only passed to
                ``SampleDrop2D.set_weights`` and used for its device. The
                recurrent state still starts empty (the first cell gets
                ``state=None``). When None, the projection of ``x`` is used.
            return_all_outputs: in training mode, also return all outputs.

        Returns:
            In training mode, ``(all_outputs, out, interim_thought)`` if
            ``return_all_outputs`` is set, else ``(out, interim_thought)``.
            In eval mode, ``all_outputs`` with shape
            ``(batch, iters_to_do, 2, x.size(2), x.size(3))``.

        Raises:
            ValueError: if ``iters_to_do < 1`` in training mode (there would
                be no last output to return).
        """
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be >= 1 in training mode")

        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate directly on the target device instead of CPU + copy.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        mul = 2  # passes over the three-cell stack per decoded output
        self.lstm.sample_mask(interim_thought.device)
        self.lstm2.sample_mask(interim_thought.device)
        self.lstm3.sample_mask(interim_thought.device)

        # forward_input(x) is computed once and reused on every pass.
        lstm_inp1 = self.lstm.forward_input(x)
        lstm_inp2 = self.lstm2.forward_input(x)
        lstm_inp3 = self.lstm3.forward_input(x)

        for i in range(iters_to_do * mul):
            if i == 0:
                self._state_drop.set_weights(interim_thought)
                state = None
            else:
                state = (interim_thought, c)

            interim_thought, c = self.lstm(lstm_inp1, state)
            interim_thought = self._state_drop(interim_thought)

            # lstm2 / lstm3 start from a zero cell state on the first pass and
            # carry their own cell state afterwards.
            if i == 0:
                state2 = (interim_thought, torch.zeros_like(c))
            else:
                state2 = (interim_thought, c2)

            interim_thought, c2 = self.lstm2(lstm_inp2, state2)
            interim_thought = self._state_drop(interim_thought)

            if i == 0:
                state3 = (interim_thought, torch.zeros_like(c))
            else:
                state3 = (interim_thought, c3)

            interim_thought, c3 = self.lstm3(lstm_inp3, state3)
            interim_thought = self._state_drop(interim_thought)

            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs
    

def _net_convlstm_ln_3l(width, kwargs, **overrides):
    """Build a ``NetConvLSTM_LN_3L`` with the settings shared by the factories below.

    Args:
        width: hidden width forwarded to the model.
        kwargs: the calling factory's keyword dict. It must contain
            ``"in_channels"`` (a missing key raises ``KeyError``).
        **overrides: extra constructor keyword arguments, forwarded verbatim.
    """
    return NetConvLSTM_LN_3L(BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
                             **overrides)


# --- layer norm / instance norm, no dropout ---------------------------------

def dt_convlstm_ln_onlyx_3l_diffc_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs)


def dt_convlstm_instn_onlyx_3l_diffc_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, use_instance_norm=True)


# --- in-cell dropout (0.2), instance norm ------------------------------------

def dt_convlstm_instn_onlyx_3l_diffc_pydrop02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               _dropout=0.2, dropout_method='pytorch', use_instance_norm=True)


def dt_convlstm_instn_onlyx_3l_diffc_gal02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               _dropout=0.2, dropout_method='gal', use_instance_norm=True)


def dt_convlstm_instn_onlyx_3l_diffc_moon02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               _dropout=0.2, dropout_method='moon', use_instance_norm=True)


def dt_convlstm_instn_onlyx_3l_diffc_sem02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               _dropout=0.2, dropout_method='semeniuta', use_instance_norm=True)


# --- in-cell dropout (0.2), layer norm ---------------------------------------

def dt_convlstm_ln_onlyx_3l_diffc_pydrop02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout=0.2, dropout_method='pytorch')


def dt_convlstm_ln_onlyx_3l_diffc_gal02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout=0.2, dropout_method='gal')


def dt_convlstm_ln_onlyx_3l_diffc_moon02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout=0.2, dropout_method='moon')


def dt_convlstm_ln_onlyx_3l_diffc_sem02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout=0.2, dropout_method='semeniuta')


# --- hidden-state SampleDrop2D dropout ---------------------------------------

def dt_convlstm_ln_onlyx_3l_difc_sgal02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout_gal2=0.2)


def dt_convlstm_ln_onlyx_3l_difc_sgal03_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout_gal2=0.3)


def dt_convlstm_instn_onlyx_3l_difc_sgal02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout_gal2=0.2, use_instance_norm=True)


def dt_convlstm_instn_onlyx_3l_difc_sgal03_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout_gal2=0.3, use_instance_norm=True)


# --- instance norm without affine parameters ---------------------------------

def dt_convlstm_in_noaff_onlyx_3l_diffc_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, use_instance_norm=True, norm_affine=False)


def dt_convlstm_in_noaff_nopre_onlyx_3l_diffc_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               use_instance_norm=True, norm_affine=False, ln_preact=False)


def dt_convlstm_in_noaff_onlyx_3l_diffc_sgal01_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               use_instance_norm=True, norm_affine=False, _dropout_gal2=0.1)


def dt_convlstm_in_noaff_onlyx_3l_diffc_sgal02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               use_instance_norm=True, norm_affine=False, _dropout_gal2=0.2)


def dt_convlstm_in_noaff_nopre_onlyx_3l_diffc_sgal02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               use_instance_norm=True, norm_affine=False, _dropout_gal2=0.2,
                               ln_preact=False)


def dt_convlstm_in_noaff_onlyx_3l_diffc_sgal03_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs,
                               use_instance_norm=True, norm_affine=False, _dropout_gal2=0.3)


# --- SampleDrop2D 0.4 + in-cell pytorch dropout, no affine norm --------------
# NOTE(review): despite "nopre" in their names, these factories do not pass
# ln_preact=False (they only set norm_affine=False), unlike the
# "in_noaff_nopre" factories above. They are left unchanged so existing runs
# stay reproducible; confirm which configuration was intended.

def dt_convlstm_ln_nopre_onlyx_3ldc_sgal04_py01_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                               _dropout=0.1, dropout_method='pytorch')


def dt_convlstm_ln_nopre_onlyx_3ldc_sgal04_py02_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                               _dropout=0.2, dropout_method='pytorch')


def dt_convlstm_ln_nopre_onlyx_3ldc_sgal04_py03_2d(width, **kwargs):
    return _net_convlstm_ln_3l(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                               _dropout=0.3, dropout_method='pytorch')



## fix dropout
class NetConvLSTM_LN_3L_1Step(nn.Module):
    """DeepThinking 2D network with three ConvLSTM cells visited one per step.

    Step ``i`` advances only cell ``i % 3`` (``lstm``, ``lstm2``, ``lstm3``).
    The hidden state is chained from cell to cell, while each cell keeps its
    own cell state. ``head`` decodes the hidden state after every step, so
    ``iters_to_do`` counts single cell steps.

    Args:
        block, num_blocks: kept for signature compatibility with the other
            DeepThinking models; no layer is built from them here.
        width: hidden channels of the cells and input channels of ``head``.
        in_channels: channels of the input ``x``.
        recall: stored on the instance; not read by ``forward``.
        group_norm: stored on the instance (only read by ``_make_layer``).
        bias: bias flag of the projection and head convolutions.
        _dropout, dropout_method: dropout configuration forwarded to every
            ``ConvLSTMCellV3``. ``dropout_method`` must be one of 'pytorch',
            'gal', 'moon', 'semeniuta', 'input'.
        ln_preact, use_instance_norm: normalisation options forwarded to
            every ``ConvLSTMCellV3``.
        _dropout_gal2: rate of the shared ``SampleDrop2D`` applied to the
            hidden state after every cell step.
        norm_affine: forwarded to every cell as ``learnable``.

    Raises:
        ValueError: if ``dropout_method`` is not supported.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,
                  _dropout=0,dropout_method='pytorch',ln_preact=True,use_instance_norm=False,
                  _dropout_gal2=0,norm_affine=True,**kwargs):
        super().__init__()

        # Explicit check rather than `assert`, so it is not stripped by `python -O`.
        if dropout_method not in ('pytorch', 'gal', 'moon', 'semeniuta', 'input'):
            raise ValueError(f"unsupported dropout_method: {dropout_method!r}")

        self.name = "NetConvLSTM_LN_3L_1Step"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Keep this construction order: seeded parameter initialisation
        # depends on the order in which modules draw from the RNG.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        def make_cell():
            # All three cells share the same configuration.
            return ConvLSTMCellV3(in_channels, width, (3, 3), True,
                                  dropout=_dropout, ln_preact=ln_preact,
                                  dropout_method=dropout_method,
                                  use_instance_norm=use_instance_norm,
                                  learnable=norm_affine)

        self.lstm = make_cell()
        self.lstm2 = make_cell()
        self.lstm3 = make_cell()

        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDrop2D(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks (not used by this model; kept for parity with DTNet)."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` single-cell steps and decode after each one.

        Args:
            x: input batch; its sizes 0, 2 and 3 define the output shape
                (assumed (batch, in_channels, H, W) -- TODO confirm).
            iters_to_do: number of cell steps. Must be >= 1 in training mode.
            interim_thought: must be None; resuming is not supported.
            return_all_outputs: in training mode, also return all outputs.

        Returns:
            In training mode, ``(all_outputs, out, interim_thought)`` if
            ``return_all_outputs`` is set, else ``(out, interim_thought)``.
            In eval mode, ``all_outputs`` with shape
            ``(batch, iters_to_do, 2, x.size(2), x.size(3))``.

        Raises:
            NotImplementedError: if ``interim_thought`` is given.
            ValueError: if ``iters_to_do < 1`` in training mode.
        """
        # Real exceptions instead of `assert`, which `python -O` would strip.
        if interim_thought is not None:
            raise NotImplementedError("resuming from interim_thought is not implemented")
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be >= 1 in training mode")

        interim_thought = self.projection(x)

        # Allocate directly on the target device instead of CPU + copy.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        self.lstm.sample_mask(interim_thought.device)
        self.lstm2.sample_mask(interim_thought.device)
        self.lstm3.sample_mask(interim_thought.device)

        # forward_input(x) is computed once and reused on every step.
        lstm_inp1 = self.lstm.forward_input(x)
        lstm_inp2 = self.lstm2.forward_input(x)
        lstm_inp3 = self.lstm3.forward_input(x)

        for i in range(iters_to_do):
            cell_idx = i % 3
            if cell_idx == 0:
                if i == 0:
                    self._state_drop.set_weights(interim_thought)
                    state = None
                else:
                    state = (interim_thought, c)
                interim_thought, c = self.lstm(lstm_inp1, state)
            elif cell_idx == 1:
                # First visit of lstm2: start from a zero cell state.
                if i == 1:
                    state2 = (interim_thought, torch.zeros_like(c))
                else:
                    state2 = (interim_thought, c2)
                interim_thought, c2 = self.lstm2(lstm_inp2, state2)
            else:
                # First visit of lstm3: start from a zero cell state.
                if i == 2:
                    state3 = (interim_thought, torch.zeros_like(c))
                else:
                    state3 = (interim_thought, c3)
                interim_thought, c3 = self.lstm3(lstm_inp3, state3)

            interim_thought = self._state_drop(interim_thought)

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def _net_convlstm_ln_3l_1step(width, kwargs, **overrides):
    """Build a ``NetConvLSTM_LN_3L_1Step`` with the settings shared by the factories below.

    ``kwargs`` is the calling factory's keyword dict and must contain
    ``"in_channels"`` (a missing key raises ``KeyError``). ``overrides`` are
    forwarded verbatim as constructor keyword arguments.
    """
    return NetConvLSTM_LN_3L_1Step(BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
                                   recall=True, **overrides)


# NOTE(review): the "ln_nopre" factories below only set norm_affine=False;
# none of them passes ln_preact=False despite the name. Confirm the intent.

def ln_nopre_onlyx_3ldc_1step(width, **kwargs):
    return _net_convlstm_ln_3l_1step(width, kwargs, norm_affine=False)


def ln_nopre_onlyx_3ldc_1step_sgal04_py01(width, **kwargs):
    return _net_convlstm_ln_3l_1step(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                                     _dropout=0.1, dropout_method='pytorch')


def ln_nopre_onlyx_3ldc_1step_sgal04_py02(width, **kwargs):
    return _net_convlstm_ln_3l_1step(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                                     _dropout=0.2, dropout_method='pytorch')


def ln_nopre_onlyx_3ldc_1step_sgal04_py03(width, **kwargs):
    return _net_convlstm_ln_3l_1step(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                                     _dropout=0.3, dropout_method='pytorch')


def lstm_3ldc_1step(width, **kwargs):
    return _net_convlstm_ln_3l_1step(width, kwargs, ln_preact=False)



## fix dropout
class NetConvLSTM_LN_1Step(nn.Module):
    """DeepThinking 2D network driven by a single ConvLSTM cell (V2).

    Each step feeds the raw input ``x`` and the previous ``(hidden, cell)``
    state to ``lstm``, applies the shared ``SampleDrop2D`` to the new hidden
    state, and decodes it with ``head``.

    Args:
        block, num_blocks: kept for signature compatibility with the other
            DeepThinking models; no layer is built from them here.
        width: hidden channels of the cell and input channels of ``head``.
        in_channels: channels of the input ``x``.
        recall: stored on the instance; not read by ``forward``.
        group_norm: stored on the instance (only read by ``_make_layer``).
        bias: bias flag of the projection and head convolutions.
        _dropout, dropout_method: dropout configuration forwarded to
            ``ConvLSTMCellV2``. ``dropout_method`` must be one of 'pytorch',
            'gal', 'moon', 'semeniuta', 'input'.
        ln_preact, use_instance_norm: normalisation options forwarded to
            ``ConvLSTMCellV2``.
        _dropout_gal2: rate of the ``SampleDrop2D`` applied to the hidden
            state after every step.
        norm_affine: forwarded to the cell as ``learnable``.

    Raises:
        ValueError: if ``dropout_method`` is not supported.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,
                  _dropout=0,dropout_method='pytorch',ln_preact=True,use_instance_norm=False,
                  _dropout_gal2=0,norm_affine=True,**kwargs):
        super().__init__()

        # Explicit check rather than `assert`, so it is not stripped by `python -O`.
        if dropout_method not in ('pytorch', 'gal', 'moon', 'semeniuta', 'input'):
            raise ValueError(f"unsupported dropout_method: {dropout_method!r}")

        # NOTE(review): unlike the sibling models, this name does not match the
        # class name. It is left as-is in case callers key on it; confirm.
        self.name = "NetConvLSTM_LN"

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Keep this construction order: seeded parameter initialisation
        # depends on the order in which modules draw from the RNG.
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.lstm = ConvLSTMCellV2(in_channels, width, (3, 3), True,
                                   dropout=_dropout, ln_preact=ln_preact,
                                   dropout_method=dropout_method,
                                   use_instance_norm=use_instance_norm,
                                   learnable=norm_affine)

        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDrop2D(dropout=self._dropout_h)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks (not used by this model; kept for parity with DTNet)."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` cell steps and decode after each one.

        Args:
            x: input batch; its sizes 0, 2 and 3 define the output shape
                (assumed (batch, in_channels, H, W) -- TODO confirm).
            iters_to_do: number of cell steps. Must be >= 1 in training mode.
            interim_thought: must be None; resuming is not supported.
            return_all_outputs: in training mode, also return all outputs.

        Returns:
            In training mode, ``(all_outputs, out, interim_thought)`` if
            ``return_all_outputs`` is set, else ``(out, interim_thought)``.
            In eval mode, ``all_outputs`` with shape
            ``(batch, iters_to_do, 2, x.size(2), x.size(3))``.

        Raises:
            NotImplementedError: if ``interim_thought`` is given.
            ValueError: if ``iters_to_do < 1`` in training mode.
        """
        # Real exceptions instead of `assert`, which `python -O` would strip.
        if interim_thought is not None:
            raise NotImplementedError("resuming from interim_thought is not implemented")
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be >= 1 in training mode")

        interim_thought = self.projection(x)

        # Allocate directly on the target device instead of CPU + copy.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        self.lstm.sample_mask(interim_thought.device)

        for i in range(iters_to_do):
            if i == 0:
                self._state_drop.set_weights(interim_thought)
                state = None
            else:
                state = (interim_thought, c)

            interim_thought, c = self.lstm(x, state)
            interim_thought = self._state_drop(interim_thought)

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs
    

def _net_convlstm_ln_1step(width, kwargs, **overrides):
    """Build a ``NetConvLSTM_LN_1Step`` with the settings shared by the factories below.

    ``kwargs`` is the calling factory's keyword dict and must contain
    ``"in_channels"`` (a missing key raises ``KeyError``). ``overrides`` are
    forwarded verbatim as constructor keyword arguments.
    """
    return NetConvLSTM_LN_1Step(BasicBlock, [2], width=width, in_channels=kwargs["in_channels"],
                                recall=True, **overrides)


# NOTE(review): despite "nopre" in the names, none of these factories passes
# ln_preact=False (they only set norm_affine=False). Confirm the intent.

def ln_nopre_onlyx_1l_1step(width, **kwargs):
    return _net_convlstm_ln_1step(width, kwargs, norm_affine=False)


def ln_nopre_onlyx_1l_1step_sgal04_py01(width, **kwargs):
    return _net_convlstm_ln_1step(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                                  _dropout=0.1, dropout_method='pytorch')


def ln_nopre_onlyx_1l_1step_sgal04_py02(width, **kwargs):
    return _net_convlstm_ln_1step(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                                  _dropout=0.2, dropout_method='pytorch')


def ln_nopre_onlyx_1l_1step_sgal04_py03(width, **kwargs):
    return _net_convlstm_ln_1step(width, kwargs, _dropout_gal2=0.4, norm_affine=False,
                                  _dropout=0.3, dropout_method='pytorch')



class BlockLN2D(nn.Module):
    """Residual 2D block: two 3x3 convolutions with optional parameter-free norm.

    The norm is ``GroupNorm`` with a single group and ``affine=False`` (a
    layer norm over channels and spatial positions). When ``group_norm`` is
    False, the norm slots hold empty ``nn.Sequential`` modules, which pass
    their input through unchanged.

    Args:
        in_planes: channels of the input.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution and of the 1x1 shortcut.
        group_norm: enable the single-group GroupNorm after each convolution.
        bias: bias flag of every convolution.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, group_norm=False, bias=False):
        super().__init__()

        def make_norm():
            if group_norm:
                return nn.GroupNorm(1, planes, affine=False)
            return nn.Sequential()

        # Module creation order is kept so seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias)
        self.gn1 = make_norm()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=bias)
        self.gn2 = make_norm()

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            # Shapes already match: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
            )

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = F.relu(self.gn1(self.conv1(x)))
        hidden = self.gn2(self.conv2(hidden))
        return F.relu(hidden + residual)


class BlockLN2D_1x(nn.Module):
    """Residual 2D block: two 1x1 convolutions with optional parameter-free norm.

    Same layout as a 3x3 residual block, but both main-path convolutions use
    ``kernel_size=1`` and no padding, so each position is mixed only across
    channels. The norm is ``GroupNorm(1, planes, affine=False)`` when
    ``group_norm`` is set; otherwise it is an empty ``nn.Sequential``, which
    passes its input through unchanged.

    Args:
        in_planes: channels of the input.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution and of the 1x1 shortcut.
        group_norm: enable the single-group GroupNorm after each convolution.
        bias: bias flag of every convolution.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, group_norm=False, bias=False):
        super().__init__()

        def make_norm():
            if group_norm:
                return nn.GroupNorm(1, planes, affine=False)
            return nn.Sequential()

        # Module creation order is kept so seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, padding=0, bias=bias)
        self.gn1 = make_norm()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=1, stride=1, padding=0, bias=bias)
        self.gn2 = make_norm()

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            # Shapes already match: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
            )

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = F.relu(self.gn1(self.conv1(x)))
        hidden = self.gn2(self.conv2(hidden))
        return F.relu(hidden + residual)


## improve dt net
class DTNetDiff(nn.Module):
    """DeepThinking 2D network with optional dropout on the recurrent state.

    The layout matches ``DTNet``: projection -> recurrent residual block
    (optionally fed ``x`` again via "recall") -> head. After every
    recurrent iteration, two regularisers may be applied to the state:

    * ``py_drop``: element-wise ``F.dropout``, applied only when training
      and ``py_drop > 0``;
    * ``gal_drop``: rate of the ``SampleDrop2D`` module, whose weights are
      set once per forward from the initial state.

    Args:
        block: residual block class used by ``_make_layer``. It must accept
            ``(in_planes, planes, stride, group_norm=..., bias=...)`` and
            expose ``expansion``.
        num_blocks: number of blocks for each recurrent stage.
        width: hidden channels.
        in_channels: channels of the input ``x``.
        recall: concatenate ``x`` to the state before every recurrent step.
        group_norm: forwarded to every block.
        bias: bias flag of every convolution.
        gal_drop: dropout rate of ``SampleDrop2D``.
        py_drop: element-wise dropout probability.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False,bias=False,
                 gal_drop=0.0,py_drop=0, **kwargs):
        super().__init__()

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        # Constructed even when recall is False (it is then discarded), as in
        # DTNet. Skipping it would shift the RNG stream used to initialise
        # every later layer.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = [conv_recall] if recall else []
        for n_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, n_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, 2, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        self.gal_drop = gal_drop
        self.py_drop = py_drop
        self._state_drop = SampleDrop2D(dropout=self.gal_drop)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; the first uses ``stride``, the rest stride 1.

        Updates ``self.width`` to the output channels of the last block.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations and decode after each one.

        Args:
            x: input batch; its sizes 0, 2 and 3 define the output shape
                (assumed (batch, in_channels, H, W) -- TODO confirm).
            iters_to_do: number of recurrent iterations. Must be >= 1 in
                training mode.
            interim_thought: optional recurrent state to resume from. When
                None, the projection of ``x`` is used.
            return_all_outputs: in training mode, also return all outputs.

        Returns:
            In training mode, ``(all_outputs, out, interim_thought)`` if
            ``return_all_outputs`` is set, else ``(out, interim_thought)``,
            where ``out`` is the head output of the last iteration.
            In eval mode, ``all_outputs`` with shape
            ``(batch, iters_to_do, 2, x.size(2), x.size(3))``.

        Raises:
            ValueError: if ``iters_to_do < 1`` in training mode (there would
                be no last output to return).
        """
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be >= 1 in training mode")

        do_dropout = self.training and self.py_drop > 0.0

        # The projection is only needed when no state is supplied.
        if interim_thought is None:
            interim_thought = self.projection(x)

        # Allocate directly on the target device instead of CPU + copy.
        all_outputs = torch.zeros((x.size(0), iters_to_do, 2, x.size(2), x.size(3)),
                                  device=x.device)

        self._state_drop.set_weights(interim_thought)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)

            if do_dropout:
                interim_thought = F.dropout(interim_thought, p=self.py_drop, training=self.training)

            interim_thought = self._state_drop(interim_thought)

            out = self.head(interim_thought)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def dt_net_recall_ln(width, **kwargs):
    return DTNetDiff(BlockLN2D, [2], width=width, in_channels=kwargs["in_channels"], recall=True, group_norm=True,
                     gal_drop=0,py_drop=0)

def dt_net_recall_ln_1x(width, **kwargs):
    return DTNetDiff(BlockLN2D_1x, [2], width=width, in_channels=kwargs["in_channels"], recall=True, group_norm=True,
                     gal_drop=0,py_drop=0)

def dt_net_recall_ln_sgal01_1x(width, **kwargs):
    return DTNetDiff(BlockLN2D_1x, [2], width=width, in_channels=kwargs["in_channels"], recall=True, group_norm=True,
                     gal_drop=0.1)
def dt_net_recall_ln_sgal02_1x(width, **kwargs):
    return DTNetDiff(BlockLN2D_1x, [2], width=width, in_channels=kwargs["in_channels"], recall=True, group_norm=True,
                     gal_drop=0.2)

def dt_net_recall_ln_sgal03_1x(width, **kwargs):
    return DTNetDiff(BlockLN2D_1x, [2], width=width, in_channels=kwargs["in_channels"], recall=True, group_norm=True,
                     gal_drop=0.3)

def dt_net_recall_ln_sgal01(width, **kwargs):
    return DTNetDiff(BlockLN2D, [2], width=width, in_channels=kwargs["in_channels"], recall=True, group_norm=True,
                     gal_drop=0.1)

def dt_net_recall_ln_sgal02(width, **kwargs):
    return DTNetDiff(BlockLN2D, [2], width=width, in_channels=kwargs["in_channels"], recall=True, group_norm=True,
                     gal_drop=0.2)

def dt_net_recall_ln_sgal03(width, **kwargs):
    return DTNetDiff(BlockLN2D, [2], width=width, in_channels=kwargs["in_channels"], recall=True, group_norm=True,
                     gal_drop=0.3)

def dt_net_recall_ln_sgal_py01(width, **kwargs):
    """Build a ``DTNetDiff`` over ``BlockLN2D`` with gal_drop=0.1 and py_drop=0.1."""
    in_channels = kwargs["in_channels"]
    return DTNetDiff(
        BlockLN2D, [2],
        width=width, in_channels=in_channels,
        recall=True, group_norm=True,
        gal_drop=0.1, py_drop=0.1,
    )

def dt_net_recall_ln_sgal_py02(width, **kwargs):
    """Build a ``DTNetDiff`` over ``BlockLN2D`` with gal_drop=0.2 and py_drop=0.2."""
    in_channels = kwargs["in_channels"]
    return DTNetDiff(
        BlockLN2D, [2],
        width=width, in_channels=in_channels,
        recall=True, group_norm=True,
        gal_drop=0.2, py_drop=0.2,
    )

def dt_net_recall_ln_sgal_py03(width, **kwargs):
    """Build a ``DTNetDiff`` over ``BlockLN2D`` with gal_drop=0.3 and py_drop=0.3."""
    in_channels = kwargs["in_channels"]
    return DTNetDiff(
        BlockLN2D, [2],
        width=width, in_channels=in_channels,
        recall=True, group_norm=True,
        gal_drop=0.3, py_drop=0.3,
    )

def dt_net_recall_ln_sgal_py04(width, **kwargs):
    """Build a ``DTNetDiff`` over ``BlockLN2D`` with gal_drop=0.4 and py_drop=0.4."""
    in_channels = kwargs["in_channels"]
    return DTNetDiff(
        BlockLN2D, [2],
        width=width, in_channels=in_channels,
        recall=True, group_norm=True,
        gal_drop=0.4, py_drop=0.4,
    )

def dt_net_recall_ln_sgal_py05(width, **kwargs):
    """Build a ``DTNetDiff`` over ``BlockLN2D`` with gal_drop=0.5 and py_drop=0.5."""
    in_channels = kwargs["in_channels"]
    return DTNetDiff(
        BlockLN2D, [2],
        width=width, in_channels=in_channels,
        recall=True, group_norm=True,
        gal_drop=0.5, py_drop=0.5,
    )

def dt_net_recall_gn_sgal_py01(width, **kwargs):
    """Build a ``DTNetDiff`` over ``BasicBlock`` (group_norm on) with gal_drop=0.1 and py_drop=0.1."""
    in_channels = kwargs["in_channels"]
    return DTNetDiff(
        BasicBlock, [2],
        width=width, in_channels=in_channels,
        recall=True, group_norm=True,
        gal_drop=0.1, py_drop=0.1,
    )


def dt_net_recall_gn_sgal01(width, **kwargs):
    """Build a ``DTNetDiff`` over ``BasicBlock`` (group_norm on) with gal_drop=0.1 and py_drop=0."""
    in_channels = kwargs["in_channels"]
    return DTNetDiff(
        BasicBlock, [2],
        width=width, in_channels=in_channels,
        recall=True, group_norm=True,
        gal_drop=0.1, py_drop=0,
    )

def dt_net_recall_gn_sgal04(width, **kwargs):
    """Build a ``DTNetDiff`` over ``BasicBlock`` (group_norm on) with gal_drop=0.4 and py_drop=0."""
    in_channels = kwargs["in_channels"]
    return DTNetDiff(
        BasicBlock, [2],
        width=width, in_channels=in_channels,
        recall=True, group_norm=True,
        gal_drop=0.4, py_drop=0,
    )


## fix dropout
class NetConvLSTM_LN_5L(nn.Module):
    """Recurrent stack of five ConvLSTM cells with a convolutional read-out head.

    Every iteration of ``forward`` calls ``lstm``, ``lstm2``, ..., ``lstm5`` in
    that order. Each cell receives the result of its own ``forward_input(x)``
    and, as hidden state, the (``SampleDropND``-dropped) hidden output of the
    previous cell. The last cell's output is handed to the first cell on the
    next iteration. ``head`` turns the final hidden state of every iteration
    into one output.

    Args:
        block, num_blocks: kept for signature compatibility with the other
            factories; only ``_make_layer`` (not called here) uses ``block``.
        width: number of hidden channels.
        in_channels: input channels; overridden to 1 when ``conv_dim == 1``.
        recall, group_norm: stored on the instance (``group_norm`` is read by
            ``_make_layer``).
        bias: ``bias`` flag for every projection/head convolution.
        _dropout, dropout_method, ln_preact, use_instance_norm: forwarded to
            every ``ConvLSTMCellV3`` (``_dropout`` as ``dropout``).
        norm_affine: forwarded to every ``ConvLSTMCellV3`` as ``learnable``.
        _dropout_gal2: ``dropout`` argument of the ``SampleDropND`` applied to
            the hidden state after every cell.
        conv_dim: 2 (``Conv2d``) or 1 (``Conv1d``).
        output_size: channels produced by the head.
        reduce: if True the head ends with a global adaptive pool, so each
            iteration yields ``(batch, output_size)``; 2D only.
        use_AvgPool: with ``reduce``, average (True) or max (False) pooling.

    Raises:
        NotImplementedError: unsupported ``conv_dim`` (or ``reduce`` with 1D).
        ValueError: unknown ``dropout_method``.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout=0, dropout_method='pytorch', ln_preact=True, use_instance_norm=False,
                 _dropout_gal2=0, norm_affine=True, conv_dim: int = 2,
                 output_size=2, reduce=False, use_AvgPool=True,
                 **kwargs):
        super().__init__()

        self.name = "NetConvLSTM_LN_5L"
        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Raise instead of `assert False`: asserts are stripped under `python -O`,
        # which previously led to a confusing NameError further down.
        if conv_dim == 2:
            conv_class = nn.Conv2d
        elif conv_dim == 1:
            conv_class = nn.Conv1d
            # 1D inputs are always treated as single-channel (inherited behaviour).
            in_channels = 1
        else:
            raise NotImplementedError("not implemented")

        def conv(c_in, c_out):
            # Every convolution in this model is 3-wide, stride 1, same padding.
            return conv_class(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=bias)

        # NOTE: modules are constructed in exactly the original order, because each
        # conv draws its initial weights from the global RNG; reordering would change
        # seeded initialisation.
        proj_conv = conv(in_channels, width)

        if reduce:
            if conv_dim != 2:
                raise NotImplementedError("not implemented conv1d")
            # Average pooling divides the summed head output by the spatial size.
            head_pool = nn.AdaptiveAvgPool2d(output_size=1) if use_AvgPool else nn.AdaptiveMaxPool2d(output_size=1)
            head_layers = [conv(width, width), nn.ReLU(),
                           conv(width, width), nn.ReLU(),
                           conv(width, output_size),
                           head_pool]
        else:
            if conv_dim == 2:
                head_conv1 = conv(width, 32)
                head_conv2 = conv(32, 8)
                head_conv3 = conv(8, output_size)
            if conv_dim == 1 or output_size != 2:
                # NOTE(review): for conv_dim == 2 with output_size != 2 the 32/8 head
                # built just above is discarded; it is still created so seeded weight
                # initialisation stays unchanged.
                half = int(width / 2)
                head_conv1 = conv(width, width)
                head_conv2 = conv(width, half)
                head_conv3 = conv(half, output_size)
            head_layers = [head_conv1, nn.ReLU(), head_conv2, nn.ReLU(), head_conv3]

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(*head_layers)

        if dropout_method not in ('pytorch', 'gal', 'moon', 'semeniuta', 'input'):
            raise ValueError(f"unknown dropout_method: {dropout_method!r}")

        def make_cell():
            return ConvLSTMCellV3(in_channels, width, (3, 3), True,
                                  dropout=_dropout, ln_preact=ln_preact, dropout_method=dropout_method,
                                  use_instance_norm=use_instance_norm, learnable=norm_affine, conv_dim=conv_dim)

        # Separate attributes (not a ModuleList) keep the existing state_dict keys.
        self.lstm = make_cell()
        self.lstm2 = make_cell()
        self.lstm3 = make_cell()
        self.lstm4 = make_cell()
        self.lstm5 = make_cell()

        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDropND(dropout=self._dropout_h)

        self.output_size = output_size
        self.reduce = reduce

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build ``num_blocks`` ``block``s (first one with ``stride``); updates ``self.width``.

        Not called by this class; kept for parity with the DTNet models.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` sweeps through the five ConvLSTM cells.

        Args:
            x: input batch. Only its rank and ``x.size(0)``/``x.size(2)``/``x.size(3)``
                are read here; presumably ``(batch, channels, H, W)`` or
                ``(batch, channels, L)`` — TODO confirm.
            iters_to_do: number of sweeps; one head output is recorded per sweep.
            interim_thought: optional initial hidden state. NOTE(review): the first
                cell is called with ``state=None`` on the first sweep, so this value is
                only passed to ``self._state_drop.set_weights`` and used for its device.
            return_all_outputs: in training mode, also return the stacked outputs.

        Returns:
            Training: ``(all_outputs, out, interim_thought)`` if ``return_all_outputs``,
            else ``(out, interim_thought)``; ``out`` is the last sweep's head output.
            Eval: ``all_outputs`` with shape ``(batch, iters_to_do, output_size)`` when
            ``reduce``, otherwise ``(batch, iters_to_do, output_size, *x.shape[2:])``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1`` (no last output exists).
            NotImplementedError: non-reduce mode with an input that is neither 3D nor 4D.
        """
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be at least 1 in training mode")

        initial_thought = self.projection(x)
        if interim_thought is None:
            interim_thought = initial_thought

        batch = x.size(0)
        if self.reduce:
            out_shape = (batch, iters_to_do, self.output_size)
        elif x.dim() == 4:
            out_shape = (batch, iters_to_do, self.output_size, x.size(2), x.size(3))
        elif x.dim() == 3:
            out_shape = (batch, iters_to_do, self.output_size, x.size(2))
        else:
            raise NotImplementedError("not implemented")
        # Allocate directly on the target device instead of CPU allocation + copy.
        all_outputs = torch.zeros(out_shape, device=x.device)

        cells = (self.lstm, self.lstm2, self.lstm3, self.lstm4, self.lstm5)
        # All masks are sampled before any forward_input call, as before, so
        # random-number consumption order is unchanged.
        for cell in cells:
            cell.sample_mask(interim_thought.device)
        cell_inputs = [cell.forward_input(x) for cell in cells]

        cell_states = [None] * len(cells)
        for i in range(iters_to_do):
            if i == 0:
                self._state_drop.set_weights(interim_thought)

            for k, cell in enumerate(cells):
                if cell_states[k] is not None:
                    # Later sweeps: incoming hidden state + this cell's own cell state.
                    state = (interim_thought, cell_states[k])
                elif k == 0:
                    # Very first call: no previous state is passed.
                    state = None
                else:
                    # First sweep, later cells: previous cell's hidden state, zero cell state.
                    state = (interim_thought, torch.zeros_like(cell_states[0]))
                interim_thought, cell_states[k] = cell(cell_inputs[k], state)
                interim_thought = self._state_drop(interim_thought)

            out = self.head(interim_thought)
            if self.reduce:
                # The pooled head output has trailing singleton spatial dims.
                out = out.view(batch, self.output_size)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs
    



def dt_convlstm_ln_nopre_onlyx_5l_sgal04_py03_2d(width, **kwargs):
    """2D ``NetConvLSTM_LN_5L``: _dropout_gal2=0.4, _dropout=0.3 ('pytorch'), norm_affine=False."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_5l_sgal05_py05_2d(width, **kwargs):
    """2D ``NetConvLSTM_LN_5L``: _dropout_gal2=0.5, _dropout=0.5 ('pytorch'), norm_affine=False."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.5,
        _dropout_gal2=0.5,
        dropout_method='pytorch',
    )


def dt_convlstm_ln_nopre_onlyx_5l_py05_2d(width, **kwargs):
    """2D ``NetConvLSTM_LN_5L``: _dropout_gal2=0, _dropout=0.5 ('pytorch'), norm_affine=False."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.5,
        _dropout_gal2=0,
        dropout_method='pytorch',
    )



def dt_convlstm_ln_nopre_onlyx_5l_2d(width, **kwargs):
    """2D ``NetConvLSTM_LN_5L`` with both dropouts set to 0. and norm_affine=False."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.,
        _dropout_gal2=0.,
        dropout_method='pytorch',
    )



def dt_convlstm_ln_nopre_onlyx_5l_sgal04_py03_1d(width, **kwargs):
    """1D (``conv_dim=1``) ``NetConvLSTM_LN_5L``: _dropout_gal2=0.4, _dropout=0.3, norm_affine=False."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
    )

def dt_convlstm_ln_5l_sgal04_py03_2d_out4_avgpool(width, **kwargs):
    """Reducing ``NetConvLSTM_LN_5L`` (average pool, 4 outputs), _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        reduce=True,
        use_AvgPool=True,
        output_size=4,
    )

def dt_convlstm_ln_5l_sgal04_py03_2d_out3_avgpool(width, **kwargs):
    """Reducing ``NetConvLSTM_LN_5L`` (average pool, 3 outputs), _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        reduce=True,
        use_AvgPool=True,
        output_size=3,
    )

def dt_convlstm_ln_5l_sgal04_py03_2d_out3_maxpool(width, **kwargs):
    """Reducing ``NetConvLSTM_LN_5L`` (max pool, 3 outputs), _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        reduce=True,
        use_AvgPool=False,
        output_size=3,
    )

def dt_convlstm_ln_5l_sgal04_py03_2d_out4_maxpool(width, **kwargs):
    """Reducing ``NetConvLSTM_LN_5L`` (max pool, 4 outputs), _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        reduce=True,
        use_AvgPool=False,
        output_size=4,
    )


def dt_convlstm_ln_5l_out4_maxpool(width, **kwargs):
    """Reducing ``NetConvLSTM_LN_5L`` (max pool, 4 outputs) with both dropouts at 0."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0,
        _dropout_gal2=0,
        dropout_method='pytorch',
        reduce=True,
        use_AvgPool=False,
        output_size=4,
    )

def dt_convlstm_ln_5l_sgal04_py03_2d_out10_avgpool(width, **kwargs):
    """Reducing ``NetConvLSTM_LN_5L`` (average pool, 10 outputs), _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_5L(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        reduce=True,
        use_AvgPool=True,
        output_size=10,
    )




## fix dropout
class NetConvLSTM_LN_NL(nn.Module):
    """Cyclic stack of ``nr_lstm_layers`` ConvLSTM cells with a convolutional head.

    ``forward`` performs ``iters_to_do * mul`` single-cell steps, cycling
    through ``lstm_list`` (step ``i`` uses cell ``i % nr_lstm_layers``). It
    records one head output after every ``mul`` steps.

    Args:
        block, num_blocks: kept for signature compatibility with the other
            factories; only ``_make_layer`` (not called here) uses ``block``.
        width: number of hidden channels.
        in_channels: input channels; overridden to 1 when ``conv_dim == 1``.
        recall, group_norm: stored on the instance (``group_norm`` is read by
            ``_make_layer``).
        bias: ``bias`` flag for every projection/head convolution.
        _dropout, dropout_method, ln_preact, use_instance_norm: forwarded to
            every ``ConvLSTMCellV3`` (``_dropout`` as ``dropout``).
        norm_affine: forwarded to every ``ConvLSTMCellV3`` as ``learnable``.
        _dropout_gal2: ``dropout`` argument of the ``SampleDropND`` applied to
            every cell's hidden output.
        conv_dim: 2 (``Conv2d``) or 1 (``Conv1d``).
        output_size: channels produced by the head.
        nr_lstm_layers: number of cells in ``lstm_list``.
        mul: cell steps per recorded output.
        use_skip, skip_mul: stored on the instance. NOTE(review): ``forward``
            does not read them (see the random-skip comment there).

    Raises:
        NotImplementedError: unsupported ``conv_dim``.
        ValueError: unknown ``dropout_method``.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout=0, dropout_method='pytorch', ln_preact=True, use_instance_norm=False,
                 _dropout_gal2=0, norm_affine=True, conv_dim: int = 2,
                 output_size=2, nr_lstm_layers=5, mul=5, skip_mul=1, use_skip=False,
                 **kwargs):
        super().__init__()

        self.name = "NetConvLSTM_LN_NL"
        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        # Raise instead of `assert False`: asserts are stripped under `python -O`,
        # which previously led to a confusing NameError further down.
        if conv_dim == 2:
            conv_class = nn.Conv2d
        elif conv_dim == 1:
            conv_class = nn.Conv1d
            # 1D inputs are always treated as single-channel (inherited behaviour).
            in_channels = 1
        else:
            raise NotImplementedError("not implemented")

        def conv(c_in, c_out):
            # Every convolution in this model is 3-wide, stride 1, same padding.
            return conv_class(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=bias)

        # NOTE: modules are constructed in exactly the original order, because each
        # conv draws its initial weights from the global RNG.
        proj_conv = conv(in_channels, width)

        if conv_dim == 2:
            head_conv1 = conv(width, 32)
            head_conv2 = conv(32, 8)
            head_conv3 = conv(8, output_size)
        if conv_dim == 1 or output_size != 2:
            # NOTE(review): for conv_dim == 2 with output_size != 2 the 32/8 head
            # built just above is discarded; it is still created so seeded weight
            # initialisation stays unchanged.
            half = int(width / 2)
            head_conv1 = conv(width, width)
            head_conv2 = conv(width, half)
            head_conv3 = conv(half, output_size)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3)

        if dropout_method not in ('pytorch', 'gal', 'moon', 'semeniuta', 'input'):
            raise ValueError(f"unknown dropout_method: {dropout_method!r}")

        self.lstm_list = nn.ModuleList([
            ConvLSTMCellV3(in_channels, width, (3, 3), True,
                           dropout=_dropout, ln_preact=ln_preact, dropout_method=dropout_method,
                           use_instance_norm=use_instance_norm, learnable=norm_affine, conv_dim=conv_dim)
            for _ in range(nr_lstm_layers)
        ])

        self.nr_lstm_layers = nr_lstm_layers
        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDropND(dropout=self._dropout_h)

        self.output_size = output_size
        self.mul = mul
        self.use_skip = use_skip
        self.skip_mul = skip_mul

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build ``num_blocks`` ``block``s (first one with ``stride``); updates ``self.width``.

        Not called by this class; kept for parity with the DTNet models.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do * self.mul`` cell steps, recording an output every ``mul`` steps.

        Args:
            x: input batch. Only its rank and ``x.size(0)``/``x.size(2)``/``x.size(3)``
                are read here; presumably ``(batch, channels, H, W)`` or
                ``(batch, channels, L)`` — TODO confirm.
            iters_to_do: number of outputs to record.
            interim_thought: optional initial hidden state (defaults to
                ``self.projection(x)``).
            return_all_outputs: in training mode, also return the stacked outputs.

        Returns:
            Training: ``(all_outputs, out, interim_thought)`` if ``return_all_outputs``,
            else ``(out, interim_thought)``; ``out`` is the last recorded head output.
            Eval: ``all_outputs`` of shape ``(batch, iters_to_do, output_size, *x.shape[2:])``.

        Raises:
            ValueError: in training mode when ``iters_to_do < 1`` (no last output exists).
            NotImplementedError: input that is neither 3D nor 4D.
        """
        if self.training and iters_to_do < 1:
            raise ValueError("iters_to_do must be at least 1 in training mode")

        initial_thought = self.projection(x)
        if interim_thought is None:
            interim_thought = initial_thought

        if x.dim() == 4:
            out_shape = (x.size(0), iters_to_do, self.output_size, x.size(2), x.size(3))
        elif x.dim() == 3:
            out_shape = (x.size(0), iters_to_do, self.output_size, x.size(2))
        else:
            raise NotImplementedError("not implemented")
        # Allocate directly on the target device instead of CPU allocation + copy.
        all_outputs = torch.zeros(out_shape, device=x.device)

        mul = self.mul
        n_cells = self.nr_lstm_layers

        # Per cell: sample its mask, then compute its input term (original interleaving).
        lstm_inputs = []
        for lstm in self.lstm_list:
            lstm.sample_mask(interim_thought.device)
            lstm_inputs.append(lstm.forward_input(x))

        c_states = [None] * n_cells
        for i in range(iters_to_do * mul):
            if i == 0:
                self._state_drop.set_weights(interim_thought)

            k = i % n_cells
            # NOTE(review): on the first visit of each cell no state is passed
            # (state=None), so the incoming interim_thought is not fed to that cell;
            # NetConvLSTM_LN_5L instead passes (hidden, zeros) there. Confirm intended.
            state = None if c_states[k] is None else (interim_thought, c_states[k])
            thought_new, c_states[k] = self.lstm_list[k](lstm_inputs[k], state)
            thought_new = self._state_drop(thought_new)

            # Random residual skip, training only, with probability 0.5 per step.
            # NOTE(review): this replaced a `use_skip`/`skip_mul` condition, so those
            # flags currently have no effect. `np` is not imported in this module's
            # header; it presumably comes from `from .blocks import *`.
            if self.training and np.random.rand() < 0.5:
                interim_thought = interim_thought + thought_new
            else:
                interim_thought = thought_new

            if i % mul == mul - 1:
                out = self.head(interim_thought)
                all_outputs[:, i // mul] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            else:
                return out, interim_thought

        return all_outputs

def convlstm_ln_nopre_onlyx_5l_sgal04_py03_2d(width, **kwargs):
    """2D ``NetConvLSTM_LN_NL`` with 5 cells: _dropout_gal2=0.4, _dropout=0.3, norm_affine=False."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        nr_lstm_layers=5,
    )

# def convlstm_ln_nopre_onlyx_2l_2d_mul2(width, **kwargs):
#     return NetConvLSTM_LN_NL(BasicBlock, [2], width=width, in_channels=kwargs["in_channels"], recall=True,
#                      _dropout_gal2=0.4,norm_affine=False,
#                     _dropout=0.3,dropout_method='pytorch',
#                     nr_lstm_layers=2,mul=2)

def convlstm_ln_nopre_onlyx_2l_1d_mul5(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 2 cells, mul=5, _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=2,
        mul=5,
    )

def convlstm_ln_nopre_onlyx_4l_1d_mul5(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 4 cells, mul=5, _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=4,
        mul=5,
    )

def convlstm_ln_nopre_onlyx_6l_1d_mul5(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 6 cells, mul=5, _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=6,
        mul=5,
    )

# NOTE: a second, byte-identical definition of convlstm_ln_nopre_onlyx_2l_1d_mul5
# used to live here. It silently re-bound the name (pylint E0102, function-redefined).
# The earlier definition of the same factory in this module is the one in effect.

def convlstm_ln_nopre_onlyx_3l_1d_mul3(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 3 cells, mul=3, _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=3,
        mul=3,
    )

def convlstm_ln_nopre_onlyx_5l_sgal04_py03_1d(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 5 cells (default mul), _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=5,
    )


def convlstm_ln_nopre_onlyx_10l_sgal04_py03_1d(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 10 cells (default mul), _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=10,
    )


def convlstm_ln_nopre_onlyx_20l_sgal04_py03_1d(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 20 cells (default mul), _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=20,
    )

def convlstm_ln_nopre_onlyx_20l_sgal04_py03_1d_mul2(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 20 cells, mul=2, _dropout_gal2=0.4, _dropout=0.3."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=20,
        mul=2,
    )



def convlstm_ln_nopre_onlyx_20l_1d_mul5(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 20 cells, mul=5, both dropouts 0."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0,
        _dropout_gal2=0,
        conv_dim=1,
        nr_lstm_layers=20,
        mul=5,
    )

def convlstm_ln_nopre_onlyx_20l_1d_mul20(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 20 cells, mul=20, both dropouts 0."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0,
        _dropout_gal2=0,
        conv_dim=1,
        nr_lstm_layers=20,
        mul=20,
    )

def convlstm_ln_nopre_onlyx_20l_1d_mul20_skip1(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 20 cells, mul=20, use_skip=True/skip_mul=1, no dropout.

    NOTE(review): ``NetConvLSTM_LN_NL.forward`` in this module does not read
    ``use_skip``/``skip_mul``.
    """
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0,
        _dropout_gal2=0,
        conv_dim=1,
        nr_lstm_layers=20,
        mul=20,
        use_skip=True,
        skip_mul=1,
    )

def convlstm_ln_nopre_onlyx_20l_sgal04_py03_1d_mul20_skip1(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: 20 cells, mul=20, use_skip=True/skip_mul=1, _dropout_gal2=0.4, _dropout=0.3.

    NOTE(review): ``NetConvLSTM_LN_NL.forward`` in this module does not read
    ``use_skip``/``skip_mul``.
    """
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0.3,
        _dropout_gal2=0.4,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=20,
        mul=20,
        use_skip=True,
        skip_mul=1,
    )


def convlstm_ln_nopre_onlyx_1l_1d_mul5(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: a single cell, mul=5, both dropouts 0."""
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0,
        _dropout_gal2=0,
        conv_dim=1,
        nr_lstm_layers=1,
        mul=5,
    )





def convlstm_ln_nopre_onlyx_1l_1d_mul5_skip1(width, **kwargs):
    """1D ``NetConvLSTM_LN_NL``: a single cell, mul=5, use_skip=True/skip_mul=1, no dropout.

    NOTE(review): ``NetConvLSTM_LN_NL.forward`` in this module does not read
    ``use_skip``/``skip_mul``.
    """
    return NetConvLSTM_LN_NL(
        BasicBlock, [2],
        width=width,
        in_channels=kwargs["in_channels"],
        recall=True,
        norm_affine=False,
        _dropout=0,
        _dropout_gal2=0,
        conv_dim=1,
        nr_lstm_layers=1,
        mul=5,
        use_skip=True,
        skip_mul=1,
    )


def convlstm_ln_nopre_onlyx_1l_sgal04_py03_1d_skip1(width, **kwargs):
    """Build a regularised ``NetConvLSTM_LN_NL`` (conv_dim=1, 1 layer, mul=5).

    Uses ``_dropout_gal2=0.4`` and ``_dropout=0.3`` with the ``'pytorch'``
    dropout method, plus ``use_skip=True`` / ``skip_mul=1``.  ``kwargs`` must
    contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_NL(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0.4,
        norm_affine=False,
        _dropout=0.3,
        dropout_method='pytorch',
        conv_dim=1,
        nr_lstm_layers=1,
        mul=5,
        use_skip=True,
        skip_mul=1,
    )

def convlstm_ln_nopre_onlyx_1l_2d_mul5_skip1(width, **kwargs):
    """Build a ``NetConvLSTM_LN_NL`` (1 layer, mul=5) with skips.

    ``conv_dim`` is not passed, so the class default applies.  Passes
    ``use_skip=True`` / ``skip_mul=1``; both dropout rates are 0 and
    ``norm_affine`` is False.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_NL(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0,
        norm_affine=False,
        _dropout=0,
        nr_lstm_layers=1,
        mul=5,
        use_skip=True,
        skip_mul=1,
    )

def convlstm_ln_nopre_onlyx_1l_1d_mul5_skip2(width, **kwargs):
    """Build a ``NetConvLSTM_LN_NL`` (conv_dim=1, 1 layer, mul=5) with skip_mul=2.

    Passes ``use_skip=True`` and ``skip_mul=2``; both dropout rates are 0 and
    ``norm_affine`` is False.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_NL(
        BasicBlock,
        [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0,
        norm_affine=False,
        _dropout=0,
        conv_dim=1,
        nr_lstm_layers=1,
        mul=5,
        use_skip=True,
        skip_mul=2,
    )



## fix dropout
class NetConvLSTM_LN_Reduce(nn.Module):
    """ConvLSTM "thinking" network whose head reduces each thought to a vector.

    A single ``lstm_class`` cell is stepped repeatedly on one fixed encoding of
    the input (``self.lstm.forward_input(x)``).  Every
    ``_LSTM_STEPS_PER_ITER`` cell steps form one iteration; at the end of each
    iteration the state-dropped hidden state goes through ``self.head`` (three
    3x3 convolutions followed by adaptive average- or max-pooling to 1x1),
    giving one ``(batch, output_size)`` prediction per iteration.

    Only ``conv_dim == 2`` is implemented (asserted in ``__init__``).
    ``self.projection`` is built and registered but none of the forward methods
    use it.
    """

    # Cell steps per reported iteration.  The original (Portuguese) comment says
    # this is meant to be equivalent to 3 LSTMs.
    _LSTM_STEPS_PER_ITER = 5

    def __init__(self, width, output_size, in_channels=3, recall=True, group_norm=False,bias=False,
                  _dropout=0,dropout_method='pytorch',ln_preact=True,use_instance_norm=False,
                  _dropout_gal2=0,norm_affine=True,
                  lstm_class=ConvLSTMCellV3,
                  conv_dim: int = 2,
                  use_AvgPool=False,**kwargs):
        super().__init__()

        self.name = "NetConvLSTM_LN_Reduce"

        self.bias = bias
        self.recall = recall  # stored only; no forward method reads it
        self.width = int(width)
        self.group_norm = group_norm

        if conv_dim == 2:
            conv_class = nn.Conv2d
        elif conv_dim == 1:
            conv_class = nn.Conv1d
            # 1-D inputs are forced to a single channel (kept from the original
            # code; the reason is not documented).
            in_channels = 1
        else:
            assert False, "not implemented"

        # NOTE: the construction order of the modules below is kept unchanged
        # because it fixes the order in which parameters are randomly initialised.
        proj_conv = conv_class(in_channels, width, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        assert conv_dim == 2, "not implemented conv1d"

        if use_AvgPool:
            # Averaging divides the summed response by the spatial size.
            head_pool = nn.AdaptiveAvgPool2d(output_size=1)
        else:
            head_pool = nn.AdaptiveMaxPool2d(output_size=1)

        head_conv1 = conv_class(width, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)
        head_conv2 = conv_class(width, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)
        head_conv3 = conv_class(width, output_size, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3,
                                  head_pool,
                                  )

        assert dropout_method in ['pytorch','gal','moon','semeniuta','input']

        self.lstm = lstm_class(in_channels, width, (3,3), True,
                               dropout=_dropout,ln_preact=ln_preact,dropout_method=dropout_method,
                               use_instance_norm=use_instance_norm,learnable=norm_affine,
                               conv_dim=conv_dim)

        # Extra dropout applied to the hidden state after every cell step
        # (see _thought_steps).
        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDropND(dropout=self._dropout_h)

        self.output_size = output_size

        assert ConvLSTMCellV3 == lstm_class or ConvLSTMCellV3_NoLayerNormalization == lstm_class

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` (not used by this class's forwards)."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def _thought_steps(self, x, iters_to_do):
        """Step the LSTM cell and yield ``(iteration, hidden)`` at each iteration end.

        Samples the cell's dropout mask, encodes ``x`` once with
        ``self.lstm.forward_input`` and steps the cell
        ``iters_to_do * _LSTM_STEPS_PER_ITER`` times on that fixed encoding.
        The state-dropout weights are set from the first hidden state, and the
        state dropout is applied to the hidden state after every step; the
        dropped hidden state is both fed back into the cell and yielded.
        """
        mul = self._LSTM_STEPS_PER_ITER
        self.lstm.sample_mask(x.device)
        lstm_inp = self.lstm.forward_input(x)

        state = None  # the first cell step receives state=None
        for step in range(iters_to_do * mul):
            hidden, c = self.lstm(lstm_inp, state)
            if step == 0:
                self._state_drop.set_weights(hidden)
            hidden = self._state_drop(hidden)
            state = (hidden, c)
            if step % mul == mul - 1:
                yield step // mul, hidden

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` iterations and predict after each one.

        Args:
            x: input batch; presumably ``(B, C, H, W)`` since the head is 2-D.
            iters_to_do: number of iterations (predictions) to produce.
            interim_thought: must be None; resuming is not implemented.
            return_all_outputs: in training mode, also return every prediction.

        Returns:
            Training mode: ``(out, interim_thought)`` for the last iteration, or
            ``(all_outputs, out, interim_thought)`` if ``return_all_outputs``.
            Eval mode: ``all_outputs`` of shape ``(B, iters_to_do, output_size)``.
        """
        if interim_thought is not None:
            assert False, "not implemented"

        all_outputs = torch.zeros((x.size(0), iters_to_do, self.output_size), device=x.device)

        for it, interim_thought in self._thought_steps(x, iters_to_do):
            out = self.head(interim_thought).view(x.size(0), self.output_size)
            all_outputs[:, it] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs

    def forward_return_hidden(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Like ``forward`` but also record the hidden state of every iteration.

        Returns:
            ``(all_outputs, all_hidden)`` with shapes
            ``(B, iters_to_do, output_size)`` and
            ``(B, iters_to_do, width, H, W)`` (``H``/``W`` taken from ``x``).
        """
        if interim_thought is not None:
            assert False, "not implemented"

        all_outputs = torch.zeros((x.size(0), iters_to_do, self.output_size), device=x.device)
        all_hidden = torch.zeros((x.size(0), iters_to_do, self.width, x.size(2), x.size(3)),
                                 device=x.device)

        for it, hidden in self._thought_steps(x, iters_to_do):
            all_outputs[:, it] = self.head(hidden).view(x.size(0), self.output_size)
            all_hidden[:, it] = hidden

        return all_outputs, all_hidden

    def forward_nomemory(self, x,targets,get_predicted_fn, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Count correct samples per iteration without storing all predictions.

        Args:
            x: input batch.
            targets: target tensor; flattened to ``(B, -1)`` before comparison.
            get_predicted_fn: called as ``get_predicted_fn(x, out)`` to turn a
                raw prediction into something comparable with ``targets``.
            iters_to_do: number of iterations to run.

        Returns:
            Training mode: same as ``forward`` but with ``corrects`` in place of
            ``all_outputs``.  Eval mode: ``corrects``, a CPU tensor of length
            ``iters_to_do`` where entry ``k`` counts samples whose prediction at
            iteration ``k`` matches every target element.
        """
        if interim_thought is not None:
            assert False, "not implemented"

        corrects = torch.zeros(iters_to_do)
        targets = targets.view(targets.size(0), -1)  # loop-invariant

        for it, interim_thought in self._thought_steps(x, iters_to_do):
            out = self.head(interim_thought).view(x.size(0), self.output_size)
            predicted = get_predicted_fn(x, out)
            # Index by iteration, not by raw cell step: corrects has one slot
            # per iteration.
            corrects[it] += torch.amin(predicted == targets, dim=[1]).sum().item()

        if self.training:
            if return_all_outputs:
                return corrects, out, interim_thought
            return out, interim_thought

        return corrects
# def dt_convlstm_ln_1l_sgal04_py03_2d_out4(width, **kwargs):
#     return NetConvLSTM_LN_Reduce(width=width,output_size=4, in_channels=kwargs["in_channels"], recall=True,
#                      _dropout_gal2=0.4,norm_affine=False,
#                     _dropout=0.3,dropout_method='pytorch')


def dt_convlstm_ln_1l_sgal04_py03_2d_out4_avgpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 4 outputs and average pooling.

    Uses ``_dropout_gal2=0.4`` and ``_dropout=0.3`` (``'pytorch'`` method) and
    a non-affine norm.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=4,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0.4,
        norm_affine=False,
        _dropout=0.3,
        dropout_method='pytorch',
        use_AvgPool=True,
    )

def dt_convlstm_ln_1l_sgal04_py03_2d_out10_avgpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 10 outputs and average pooling.

    Uses ``_dropout_gal2=0.4`` and ``_dropout=0.3`` (``'pytorch'`` method) and
    a non-affine norm.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=10,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0.4,
        norm_affine=False,
        _dropout=0.3,
        dropout_method='pytorch',
        use_AvgPool=True,
    )


def dt_convlstm_ln_1l_2d_out4_avgpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 4 outputs and average pooling.

    The dropout arguments are not passed, so the class defaults apply; the norm
    is non-affine.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=4,
        in_channels=in_channels,
        recall=True,
        norm_affine=False,
        use_AvgPool=True,
    )

def dt_convlstm_ln_1l_2d_out3_avgpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 3 outputs and average pooling.

    The dropout arguments are not passed, so the class defaults apply; the norm
    is non-affine.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=3,
        in_channels=in_channels,
        recall=True,
        norm_affine=False,
        use_AvgPool=True,
    )

def dt_convlstm_ln_1l_sgal04_py03_2d_out3_avgpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 3 outputs and average pooling.

    Uses ``_dropout_gal2=0.4`` and ``_dropout=0.3`` (``'pytorch'`` method) and
    a non-affine norm.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=3,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0.4,
        norm_affine=False,
        _dropout=0.3,
        dropout_method='pytorch',
        use_AvgPool=True,
    )

def dt_convlstm_ln_1l_sgal04_py03_2d_out4_maxpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 4 outputs and max pooling.

    Uses ``_dropout_gal2=0.4`` and ``_dropout=0.3`` (``'pytorch'`` method) and
    a non-affine norm.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=4,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0.4,
        norm_affine=False,
        _dropout=0.3,
        dropout_method='pytorch',
        use_AvgPool=False,
    )


def dt_convlstm_noln_1l_sgal04_py03_2d_out4_maxpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 4 outputs, max pooling and no LayerNorm cell.

    Uses ``ConvLSTMCellV3_NoLayerNormalization`` as the cell,
    ``_dropout_gal2=0.4`` and ``_dropout=0.3`` (``'pytorch'`` method).
    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=4,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0.4,
        norm_affine=False,
        _dropout=0.3,
        dropout_method='pytorch',
        lstm_class=ConvLSTMCellV3_NoLayerNormalization,
        use_AvgPool=False,
    )


def dt_convlstm_noln_1l_sgal04_py03_2d_out3_maxpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 3 outputs, max pooling and no LayerNorm cell.

    Uses ``ConvLSTMCellV3_NoLayerNormalization`` as the cell,
    ``_dropout_gal2=0.4`` and ``_dropout=0.3`` (``'pytorch'`` method).
    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=3,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0.4,
        norm_affine=False,
        _dropout=0.3,
        dropout_method='pytorch',
        lstm_class=ConvLSTMCellV3_NoLayerNormalization,
        use_AvgPool=False,
    )

def dt_convlstm_ln_1l_sgal04_py03_2d_out3_maxpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 3 outputs and max pooling.

    Uses ``_dropout_gal2=0.4`` and ``_dropout=0.3`` (``'pytorch'`` method) and
    a non-affine norm.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=3,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0.4,
        norm_affine=False,
        _dropout=0.3,
        dropout_method='pytorch',
        use_AvgPool=False,
    )

def dt_convlstm_ln_1l_2d_out4_maxpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 4 outputs, max pooling, no dropout.

    Both dropout rates are explicitly 0 (``'pytorch'`` method); the norm is
    non-affine.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=4,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0,
        norm_affine=False,
        _dropout=0,
        dropout_method='pytorch',
        use_AvgPool=False,
    )

def dt_convlstm_ln_1l_2d_out3_maxpool(width, **kwargs):
    """Build a ``NetConvLSTM_LN_Reduce`` with 3 outputs, max pooling, no dropout.

    Both dropout rates are explicitly 0 (``'pytorch'`` method); the norm is
    non-affine.  ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvLSTM_LN_Reduce(
        width=width,
        output_size=3,
        in_channels=in_channels,
        recall=True,
        _dropout_gal2=0,
        norm_affine=False,
        _dropout=0,
        dropout_method='pytorch',
        use_AvgPool=False,
    )




class FeedForwardNetAvgPoolEnd(nn.Module):
    """Per-iteration (non weight-shared) network with a spatially pooled head.

    Each of the ``max_iters`` iterations has its own stack of blocks in
    ``self.feedforward_layers``.  If ``recall`` is set, a single recall
    convolution (shared by all iterations) is applied to ``cat([thought, x])``
    before every stack.  After each iteration the head (three 3x3 convolutions
    followed by adaptive average- or max-pooling to 1x1) produces a
    ``(batch, output_size)`` prediction.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True, max_iters=8, group_norm=False,
                 use_AvgPool=True, paper_out_head=False,):
        super().__init__()

        self.width = int(width)
        self.recall = recall
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3, stride=1, padding=1, bias=False)

        if self.recall:
            self.recall_layer = nn.Conv2d(width+in_channels, width, kernel_size=3,
                                          stride=1, padding=1, bias=False)
        else:
            self.recall_layer = nn.Sequential()

        # One independent block stack per iteration.
        self.feedforward_layers = nn.ModuleList()
        for _ in range(max_iters):
            internal_block = []
            for j in range(len(num_blocks)):
                internal_block.append(self._make_layer(block, width, num_blocks[j], stride=1))
            self.feedforward_layers.append(nn.Sequential(*internal_block))

        if use_AvgPool:
            # Averaging divides the summed response by the spatial size.
            head_pool = nn.AdaptiveAvgPool2d(output_size=1)
        else:
            head_pool = nn.AdaptiveMaxPool2d(output_size=1)

        head_conv1 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv2 = nn.Conv2d(width, 32, kernel_size=3, stride=1, padding=1, bias=False)
        head_conv3 = nn.Conv2d(32, output_size, kernel_size=3, stride=1, padding=1, bias=False)

        if paper_out_head:
            # 64-channel head.  The default convolutions above are still built
            # first so random initialisation consumes the RNG in the same order.
            head_conv1 = nn.Conv2d(width, 64, kernel_size=3, stride=1, padding=1, bias=False)
            head_conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
            head_conv3 = nn.Conv2d(64, output_size, kernel_size=3, stride=1, padding=1, bias=False)

        self.iters = max_iters
        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(), head_conv2, nn.ReLU(), head_conv3, head_pool)

        self.output_size = output_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block`` (group_norm passed positionally)."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, self.group_norm))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, iters_elapsed=0, **kwargs):
        """Run up to ``iters_to_do`` stacks starting at stack ``iters_elapsed``.

        Args:
            x: input batch, presumably ``(B, C, H, W)``.
            iters_to_do: number of predictions to produce.
            interim_thought: thought to resume from; defaults to
                ``self.projection(x)``.
            iters_elapsed: index of the first feed-forward stack to run.

        Returns:
            Training mode: ``(out, interim_thought)`` for the last prediction.
            Eval mode: ``all_outputs`` of shape ``(B, iters_to_do, output_size)``.

        Only ``max_iters - iters_elapsed`` stacks are available.  Once they are
        exhausted, the remaining output slots repeat the last prediction.  If no
        stack is left at all, the head is applied to the incoming thought.
        """
        if interim_thought is None:
            interim_thought = self.projection(x)

        batch = x.size(0)
        all_outputs = torch.zeros((batch, iters_to_do, self.output_size), device=x.device)

        out = None
        n_run = 0  # number of stacks actually executed (may be < iters_to_do)
        for layer in self.feedforward_layers[iters_elapsed:iters_elapsed + iters_to_do]:
            if self.recall:
                interim_thought = self.recall_layer(torch.cat([interim_thought, x], 1))
            interim_thought = layer(interim_thought)
            out = self.head(interim_thought).view(batch, self.output_size)
            all_outputs[:, n_run] = out
            n_run += 1

        if out is None and (self.training or iters_to_do > 0):
            # No stack left to run: the thought can no longer change, so read
            # the prediction straight from it instead of failing on `out`.
            out = self.head(interim_thought).view(batch, self.output_size)

        if n_run < iters_to_do:
            # Fill from the number of stacks actually run (not from max_iters),
            # so resumed runs (iters_elapsed > 0) leave no zero gaps.
            all_outputs[:, n_run:] = out.unsqueeze(1)

        if self.training:
            return out, interim_thought
        return all_outputs


def feedforward_net_recall_2d_out4_avgpool(width, **kwargs):
    """Build a recall ``FeedForwardNetAvgPoolEnd`` with 4 outputs and average pooling.

    ``kwargs`` must contain ``"in_channels"`` and ``"max_iters"``.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNetAvgPoolEnd(
        BasicBlock, [2], width=width, output_size=4,
        in_channels=in_channels, recall=True, max_iters=max_iters,
    )


def feedforward_net_recall_2d_out3_avgpool(width, **kwargs):
    """Build a recall ``FeedForwardNetAvgPoolEnd`` with 3 outputs and average pooling.

    ``kwargs`` must contain ``"in_channels"`` and ``"max_iters"``.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNetAvgPoolEnd(
        BasicBlock, [2], width=width, output_size=3,
        in_channels=in_channels, recall=True, max_iters=max_iters,
    )

def feedforward_net_recall_2d_out4_maxpool(width, **kwargs):
    """Build a recall ``FeedForwardNetAvgPoolEnd`` with 4 outputs and max pooling.

    ``kwargs`` must contain ``"in_channels"`` and ``"max_iters"``.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNetAvgPoolEnd(
        BasicBlock, [2], width=width, output_size=4,
        in_channels=in_channels, recall=True, max_iters=max_iters,
        use_AvgPool=False,
    )

def feedforward_net_recall_2d_out3_maxpool(width, **kwargs):
    """Build a recall ``FeedForwardNetAvgPoolEnd`` with 3 outputs and max pooling.

    ``kwargs`` must contain ``"in_channels"`` and ``"max_iters"``.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNetAvgPoolEnd(
        BasicBlock, [2], width=width, output_size=3,
        in_channels=in_channels, recall=True, max_iters=max_iters,
        use_AvgPool=False,
    )



def feedforward_net_recall_2d_out4_maxpool_fixhead(width, **kwargs):
    """Build a recall ``FeedForwardNetAvgPoolEnd`` (4 outputs, max pool, 64-channel head).

    ``kwargs`` must contain ``"in_channels"`` and ``"max_iters"``.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNetAvgPoolEnd(
        BasicBlock, [2], width=width, output_size=4,
        in_channels=in_channels, recall=True, max_iters=max_iters,
        use_AvgPool=False, paper_out_head=True,
    )


def feedforward_net_recall_2d_out3_maxpool_fixhead(width, **kwargs):
    """Build a recall ``FeedForwardNetAvgPoolEnd`` (3 outputs, max pool, 64-channel head).

    ``kwargs`` must contain ``"in_channels"`` and ``"max_iters"``.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNetAvgPoolEnd(
        BasicBlock, [2], width=width, output_size=3,
        in_channels=in_channels, recall=True, max_iters=max_iters,
        use_AvgPool=False, paper_out_head=True,
    )



def feedforward_net_2d_out3_avgpool(width, **kwargs):
    """Build a ``FeedForwardNetAvgPoolEnd`` without recall (3 outputs, average pooling).

    ``kwargs`` must contain ``"in_channels"`` and ``"max_iters"``.
    """
    in_channels = kwargs["in_channels"]
    max_iters = kwargs["max_iters"]
    return FeedForwardNetAvgPoolEnd(
        BasicBlock, [2], width=width, output_size=3,
        in_channels=in_channels, recall=False, max_iters=max_iters,
    )



class DTNetAVGpool(nn.Module):
    """DeepThinking network with a spatially pooled prediction head.

    A projection, a shared recurrent block (optionally preceded by a recall
    convolution over ``cat([thought, x])``) applied ``iters_to_do`` times, and
    a head run after every iteration.  The head ends in adaptive average- or
    max-pooling to 1x1, so each iteration yields a ``(batch, output_size)``
    vector.  The ``bias`` flag applies to every convolution, including the
    ``paper_out_head`` variant of the head.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True, group_norm=False,bias=False,
        use_AvgPool=True,paper_out_head=False,
     **kwargs):
        super().__init__()

        self.bias = bias
        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)
        # Input of the recall conv is cat([thought, x]) along channels.
        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = [conv_recall] if recall else []
        for n_blocks in num_blocks:
            recur_layers.append(self._make_layer(block, width, n_blocks, stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, output_size, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        if paper_out_head:
            # 64-channel head.  The default convolutions above are still built
            # first so random initialisation consumes the RNG in the same order.
            # Honour `bias` here too (it was previously hard-coded to False).
            head_conv1 = nn.Conv2d(width, 64, kernel_size=3, stride=1, padding=1, bias=bias)
            head_conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=bias)
            head_conv3 = nn.Conv2d(64, output_size, kernel_size=3, stride=1, padding=1, bias=bias)

        if use_AvgPool:
            # Averaging divides the summed response by the spatial size.
            head_pool = nn.AdaptiveAvgPool2d(output_size=1)
        else:
            head_pool = nn.AdaptiveMaxPool2d(output_size=1)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3, head_pool)

        self.output_size = output_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block``, updating ``self.width``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` recurrent iterations, predicting after each one.

        Args:
            x: input batch, presumably ``(B, C, H, W)``.
            iters_to_do: number of iterations (predictions) to produce.
            interim_thought: thought to resume from; defaults to
                ``self.projection(x)``.
            return_all_outputs: in training mode, also return every prediction.

        Returns:
            Training mode: ``(out, interim_thought)`` for the last iteration, or
            ``(all_outputs, out, interim_thought)`` if ``return_all_outputs``.
            Eval mode: ``all_outputs`` of shape ``(B, iters_to_do, output_size)``.
        """
        if interim_thought is None:
            # Only pay for the projection when there is no thought to resume.
            interim_thought = self.projection(x)

        all_outputs = torch.zeros((x.size(0), iters_to_do, self.output_size), device=x.device)

        for i in range(iters_to_do):
            if self.recall:
                interim_thought = torch.cat([interim_thought, x], 1)
            interim_thought = self.recur_block(interim_thought)
            out = self.head(interim_thought).view(x.size(0), self.output_size)
            all_outputs[:, i] = out

        if self.training:
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def dt_net_recall_2d_out3_avg_pool(width, **kwargs):
    """Build a recall ``DTNetAVGpool`` with 3 outputs and average pooling.

    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return DTNetAVGpool(
        BasicBlock, [2], width=width, output_size=3,
        in_channels=in_channels, recall=True,
    )


def dt_net_recall_2d_out4_avg_pool(width, **kwargs):
    """Build a recall ``DTNetAVGpool`` with 4 outputs and average pooling.

    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return DTNetAVGpool(
        BasicBlock, [2], width=width, output_size=4,
        in_channels=in_channels, recall=True,
    )

def dt_net_recall_2d_out3_maxpool(width, **kwargs):
    """Build a recall ``DTNetAVGpool`` with 3 outputs and max pooling.

    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return DTNetAVGpool(
        BasicBlock, [2], width=width, output_size=3,
        in_channels=in_channels, recall=True,
        use_AvgPool=False,
    )

def dt_net_recall_2d_out3_maxpool_fixhead(width, **kwargs):
    """Build a recall ``DTNetAVGpool`` (3 outputs, max pool, 64-channel head).

    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return DTNetAVGpool(
        BasicBlock, [2], width=width, output_size=3,
        in_channels=in_channels, recall=True,
        use_AvgPool=False, paper_out_head=True,
    )


def dt_net_recall_2d_out4_maxpool(width, **kwargs):
    """Build a recall ``DTNetAVGpool`` with 4 outputs and max pooling.

    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return DTNetAVGpool(
        BasicBlock, [2], width=width, output_size=4,
        in_channels=in_channels, recall=True,
        use_AvgPool=False,
    )


def dt_net_recall_2d_out4_maxpool_fixhead(width, **kwargs):
    """Build a recall ``DTNetAVGpool`` (4 outputs, max pool, 64-channel head).

    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return DTNetAVGpool(
        BasicBlock, [2], width=width, output_size=4,
        in_channels=in_channels, recall=True,
        use_AvgPool=False, paper_out_head=True,
    )

class DTNetRandom(nn.Module):
    """Random-prediction baseline with the DTNetAVGpool interface.

    ``forward`` ignores the values of ``x`` and returns uniform random numbers
    in ``[0, 1)`` shaped like real predictions.  The layers built in
    ``__init__`` are registered as submodules but never used by ``forward``.
    """

    def __init__(self, block, num_blocks, width, output_size, in_channels=3, recall=True, group_norm=False,bias=False, **kwargs):
        super().__init__()

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm
        proj_conv = nn.Conv2d(in_channels, width, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        conv_recall = nn.Conv2d(width + in_channels, width, kernel_size=3,
                                stride=1, padding=1, bias=bias)

        recur_layers = []
        if recall:
            recur_layers.append(conv_recall)

        for i in range(len(num_blocks)):
            recur_layers.append(self._make_layer(block, width, num_blocks[i], stride=1))

        head_conv1 = nn.Conv2d(width, 32, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv2 = nn.Conv2d(32, 8, kernel_size=3,
                               stride=1, padding=1, bias=bias)
        head_conv3 = nn.Conv2d(8, output_size, kernel_size=3,
                               stride=1, padding=1, bias=bias)

        head_pool = nn.AdaptiveAvgPool2d(output_size=1)

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.recur_block = nn.Sequential(*recur_layers)
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3,head_pool)

        self.output_size = output_size

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` instances of ``block``, updating ``self.width``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Return uniform random "predictions"; only ``x``'s batch size and device are used.

        Returns:
            Training mode: ``(all_outputs[:, -1], interim_thought)`` (the given
            ``interim_thought`` is passed through unchanged), or
            ``(all_outputs, out, interim_thought)`` if ``return_all_outputs``.
            Eval mode: ``all_outputs`` of shape ``(B, iters_to_do, output_size)``.
        """
        # Allocate directly on x's device: avoids a host allocation + copy and
        # keeps the tensor a grad-requiring leaf on every device.
        all_outputs = torch.rand((x.size(0), iters_to_do, self.output_size),
                                 device=x.device, requires_grad=True)

        if self.training:
            out = all_outputs[:, -1]
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs


def dt_net_random_out3(width, **kwargs):
    """Build a ``DTNetRandom`` baseline with 3 outputs.

    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return DTNetRandom(
        BasicBlock, [2], width=width, output_size=3,
        in_channels=in_channels, recall=True,
    )

def dt_net_random_out4(width, **kwargs):
    """Build a ``DTNetRandom`` baseline with 4 outputs.

    ``kwargs`` must contain ``"in_channels"``.
    """
    in_channels = kwargs["in_channels"]
    return DTNetRandom(
        BasicBlock, [2], width=width, output_size=4,
        in_channels=in_channels, recall=True,
    )

## fix dropout
class NetConvNOLSTM_LN(nn.Module):
    """Recurrent single-convolution "thinking" network (no LSTM cell).

    Every recurrent step computes ``act(conv_recall(cat([x, state], dim=1)))``,
    optionally applies ``GroupNorm(1, width)`` (``use_ln``) and then the
    ``SampleDropND`` state dropout.  After every 5 recurrent steps the current
    state is passed through ``self.head`` and stored as one output iteration.

    With ``reduce=False`` the head convolutions keep the spatial extent of the
    input, so outputs have shape ``(B, iters, output_size, *spatial)``, or
    ``(B, iters, output_size * prod(spatial))`` when ``flatten=True``.  With
    ``reduce=True`` (2-D only) the head ends in a global average/max pool and
    outputs have shape ``(B, iters, output_size)``.
    """

    def __init__(self, block, num_blocks, width, in_channels=3, recall=True, group_norm=False, bias=False,
                 _dropout=0, dropout_method='pytorch', use_ln=True, use_instance_norm=False,
                 _dropout_gal2=0, norm_affine=True,
                 lstm_class=ConvLSTMCellV3,
                 conv_dim: int = 2, reduce=False, use_AvgPool=True,
                 output_size=2, flatten=False, act=F.relu, **kwargs):
        """Build the projection, recurrent conv, normalisation, dropout and head.

        Args:
            block, num_blocks: Not used to build this model (only
                ``_make_layer`` would use ``block``); kept so the signature
                matches the other DeepThinking nets.
            width: Number of channels of the recurrent state.
            in_channels: Input channels; forced to 1 when ``conv_dim == 1``.
            recall: Stored on the instance; ``forward`` always concatenates
                ``x`` to the state regardless of it.
            group_norm: Stored and passed to blocks built by ``_make_layer``.
            bias: Whether the convolutions have a bias term.
            _dropout, dropout_method: ``dropout_method`` is validated; neither
                is otherwise used by this model.
            use_ln: Apply ``GroupNorm(1, width)`` after each recurrent step.
            use_instance_norm: Unused.
            _dropout_gal2: Dropout rate handed to ``SampleDropND`` for the state.
            norm_affine: ``affine`` flag of the GroupNorm.
            lstm_class: Must be ``ConvLSTMCellV3``; otherwise unused.
            conv_dim: 2 for ``nn.Conv2d`` or 1 for ``nn.Conv1d``.
            reduce: Pool the head output to one vector per sample (2-D only).
            use_AvgPool: With ``reduce``, use adaptive average (True) or max
                (False) pooling.
            output_size: Number of output channels of the head.
            flatten: Flatten each per-iteration output to ``(B, -1)``.
            act: Activation applied after ``conv_recall``.
            **kwargs: Ignored.

        Raises:
            NotImplementedError: ``conv_dim`` other than 1 or 2, or
                ``reduce=True`` together with ``conv_dim == 1``.
        """
        super().__init__()

        # NOTE(review): differs from the class name; kept unchanged in case
        # logging/checkpoint code keys on it.
        self.name = "NetConvLSTM_LN"

        self.bias = bias

        self.recall = recall
        self.width = int(width)
        self.group_norm = group_norm

        if conv_dim == 2:
            conv_class = nn.Conv2d
        elif conv_dim == 1:
            conv_class = nn.Conv1d
            # Inherited from the original 1-D code, which uses a single input
            # channel; the reason is not documented.
            in_channels = 1
        else:
            # A real exception: an ``assert`` would vanish under ``python -O``
            # and leave ``conv_class`` undefined.
            raise NotImplementedError(f"conv_dim={conv_dim} is not implemented")

        if reduce and conv_dim != 2:
            raise NotImplementedError("reduce=True is not implemented for conv1d")

        def conv(c_in, c_out):
            # Every convolution here is 3-wide, stride 1, padding 1.
            return conv_class(c_in, c_out, kernel_size=3,
                              stride=1, padding=1, bias=bias)

        # Construction order matters: each conv draws from the global RNG when
        # initialising its weights, and submodule registration order fixes the
        # order of ``parameters()``.  Both are kept as they always were.
        proj_conv = conv(in_channels, width)
        self.conv_recall = conv(width + in_channels, width)
        self.ln_cell = nn.GroupNorm(1, width, affine=norm_affine)
        self.use_ln = use_ln

        self.act = act

        # One recurrent block of the original DT nets runs 1 + 2*2 = 5 convs;
        # forward() therefore takes 5 single-conv steps per output.
        if reduce:
            if use_AvgPool:
                # Averaging divides every output by the number of positions.
                head_pool = nn.AdaptiveAvgPool2d(output_size=1)
            else:
                head_pool = nn.AdaptiveMaxPool2d(output_size=1)
            head_conv1 = conv(width, width)
            head_conv2 = conv(width, width)
            head_conv3 = conv(width, output_size)
            head_tail = [head_pool]
        else:
            if conv_dim == 2:
                head_conv1 = conv(width, 32)
                head_conv2 = conv(32, 8)
                head_conv3 = conv(8, output_size)
            if conv_dim == 1 or output_size != 2:
                # Wider head for 1-D inputs and for 2-D with output_size != 2.
                # In the 2-D case this replaces the 32/8 head built just above;
                # that head is still constructed so the RNG stream (and hence
                # seeded initialisation) is unchanged.
                head_conv1 = conv(width, width)
                head_conv2 = conv(width, int(width / 2))
                head_conv3 = conv(int(width / 2), output_size)
            head_tail = []

        self.projection = nn.Sequential(proj_conv, nn.ReLU())
        self.head = nn.Sequential(head_conv1, nn.ReLU(),
                                  head_conv2, nn.ReLU(),
                                  head_conv3, *head_tail)

        assert dropout_method in ['pytorch', 'gal', 'moon', 'semeniuta', 'input']

        self._dropout_h = _dropout_gal2
        self._state_drop = SampleDropND(dropout=self._dropout_h)

        self.output_size = output_size

        assert ConvLSTMCellV3 == lstm_class

        self.flatten = flatten
        self.reduce = reduce

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``block`` instances into an ``nn.Sequential`` stage.

        Not called by this model; kept for parity with the other DT nets.
        The first block gets ``stride``, the rest stride 1, and ``self.width``
        is updated to ``planes * block.expansion`` after each block.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for strd in strides:
            layers.append(block(self.width, planes, strd, group_norm=self.group_norm, bias=self.bias))
            self.width = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, iters_to_do, interim_thought=None, return_all_outputs=False, **kwargs):
        """Run ``iters_to_do`` output iterations of 5 recurrent steps each.

        Args:
            x: Input batch; concatenated with the state along dim 1 at every
                recurrent step.
            iters_to_do: Number of outputs to produce.
            interim_thought: Optional starting thought.  NOTE(review): it is
                not used as the initial state.  The state is reset to zeros on
                the first step, so this tensor only determines that zero
                state's shape and is passed to ``self._state_drop.set_weights``.
                Confirm this is intended for incremental-progress training.
            return_all_outputs: In training mode, also return all outputs.
            **kwargs: Ignored.

        Returns:
            Eval mode: ``all_outputs``.  Training mode: ``(out, thought)`` or
            ``(all_outputs, out, thought)``, where ``out`` is the last head
            output and ``thought`` the final recurrent state.

        Raises:
            NotImplementedError: Non-``reduce`` input that is neither 3-D nor
                4-D.
            ValueError: ``iters_to_do < 1`` in training mode, where there is no
                last output to return.
        """
        # Used only when no interim_thought is given; its values are not fed
        # to conv_recall because the state starts at zero.
        initial_thought = self.projection(x)

        if interim_thought is None:
            interim_thought = initial_thought

        batch_size = x.size(0)
        if self.reduce:
            per_iter_shape = (self.output_size,)
        elif x.dim() == 4:
            per_iter_shape = (self.output_size, x.size(2), x.size(3))
        elif x.dim() == 3:
            per_iter_shape = (self.output_size, x.size(2))
        else:
            raise NotImplementedError(f"inputs of rank {x.dim()} are not implemented")
        if self.flatten:
            # Must match ``out.flatten(1)`` below for 3-D and 4-D inputs alike.
            per_iter_shape = (torch.Size(per_iter_shape).numel(),)
        all_outputs = torch.zeros((batch_size, iters_to_do) + per_iter_shape,
                                  device=x.device)

        # 5 single-conv steps per output, matching the 5 convolutions of one
        # recurrent block in the original DT nets (the original author's note
        # says "to be equivalent to 3 LSTMs").
        steps_per_output = 5
        out = None
        for step in range(iters_to_do * steps_per_output):
            if step == 0:
                self._state_drop.set_weights(interim_thought)
                state = torch.zeros_like(interim_thought)
            else:
                state = interim_thought

            interim_thought = self.act(self.conv_recall(torch.cat([x, state], 1)))

            if self.use_ln:
                interim_thought = self.ln_cell(interim_thought)

            interim_thought = self._state_drop(interim_thought)

            if step % steps_per_output == steps_per_output - 1:
                out = self.head(interim_thought)
                if self.reduce:
                    out = out.view(batch_size, self.output_size)
                if self.flatten:
                    out = out.flatten(1)
                all_outputs[:, step // steps_per_output] = out

        if self.training:
            if out is None:
                raise ValueError("iters_to_do must be >= 1 in training mode")
            if return_all_outputs:
                return all_outputs, out, interim_thought
            return out, interim_thought

        return all_outputs
    

def dt_conv_nolstm_ln_1l_sgal04_py03_2d(width, **kwargs):
    """2-D ``NetConvNOLSTM_LN`` with GroupNorm, state dropout 0.4 and ReLU.

    ``kwargs`` must contain ``in_channels``.  ``_dropout=0.3`` is passed
    through but is not used by the model itself.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        use_ln=True,
        norm_affine=False,
        _dropout_gal2=0.4,
        _dropout=0.3,
        dropout_method='pytorch',
    )

def dt_conv_nolstm_noln_1l_sgal04_py03_2d(width, **kwargs):
    """2-D ``NetConvNOLSTM_LN`` without GroupNorm, state dropout 0.4, ReLU.

    ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        use_ln=False,
        norm_affine=False,
        _dropout_gal2=0.4,
        _dropout=0.3,
        dropout_method='pytorch',
    )

def dt_conv_nolstm_leaky_noln_1l_sgal04_py03_2d(width, **kwargs):
    """2-D ``NetConvNOLSTM_LN`` without GroupNorm, state dropout 0.4, LeakyReLU.

    ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        act=F.leaky_relu,
        use_ln=False,
        norm_affine=False,
        _dropout_gal2=0.4,
        _dropout=0.3,
        dropout_method='pytorch',
    )

def dt_conv_nolstm_noln_1l_2d(width, **kwargs):
    """2-D ``NetConvNOLSTM_LN`` without GroupNorm or dropout, ReLU.

    ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        use_ln=False,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )


def dt_conv_nolstm_mish_ln_1l_sgal04_py03_2d(width, **kwargs):
    """2-D ``NetConvNOLSTM_LN`` with GroupNorm, state dropout 0.4 and Mish.

    ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        act=F.mish,
        use_ln=True,
        norm_affine=False,
        _dropout_gal2=0.4,
        _dropout=0.3,
        dropout_method='pytorch',
    )

def dt_conv_nolstm_leaky_ln_1l_sgal04_py03_2d(width, **kwargs):
    """2-D ``NetConvNOLSTM_LN`` with GroupNorm, state dropout 0.4, LeakyReLU.

    ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        act=F.leaky_relu,
        use_ln=True,
        norm_affine=False,
        _dropout_gal2=0.4,
        _dropout=0.3,
        dropout_method='pytorch',
    )



def dt_conv_nolstm_mish_ln_1l_2d(width, **kwargs):
    """2-D ``NetConvNOLSTM_LN`` with GroupNorm, no dropout, Mish activation.

    ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        act=F.mish,
        use_ln=True,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )




def dt_conv_nolstm_noln_1l_2d_out4_avgpool(width, **kwargs):
    """Pooled 2-D ``NetConvNOLSTM_LN``: 4 outputs via global average pooling.

    No GroupNorm, no dropout, ReLU.  ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        reduce=True,
        output_size=4,
        use_AvgPool=True,
        use_ln=False,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )

def dt_conv_nolstm_noln_1l_2d_out4_maxpool(width, **kwargs):
    """Pooled 2-D ``NetConvNOLSTM_LN``: 4 outputs via global max pooling.

    No GroupNorm, no dropout, ReLU.  ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        reduce=True,
        output_size=4,
        use_AvgPool=False,
        use_ln=False,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )

def dt_conv_nolstm_noln_1l_2d_out3_avgpool(width, **kwargs):
    """Pooled 2-D ``NetConvNOLSTM_LN``: 3 outputs via global average pooling.

    No GroupNorm, no dropout, ReLU.  ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        reduce=True,
        output_size=3,
        use_AvgPool=True,
        use_ln=False,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )

def dt_conv_nolstm_noln_1l_2d_out3_maxpool(width, **kwargs):
    """Pooled 2-D ``NetConvNOLSTM_LN``: 3 outputs via global max pooling.

    No GroupNorm, no dropout, ReLU.  ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        reduce=True,
        output_size=3,
        use_AvgPool=False,
        use_ln=False,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )

def dt_conv_nolstm_noln_1l_2d_out10_avgpool(width, **kwargs):
    """Pooled 2-D ``NetConvNOLSTM_LN``: 10 outputs via global average pooling.

    No GroupNorm, no dropout, ReLU.  ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        reduce=True,
        output_size=10,
        use_AvgPool=True,
        use_ln=False,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )



def dt_conv_nolstm_mish_ln_1l_2d_out4_avgpool(width, **kwargs):
    """Pooled 2-D ``NetConvNOLSTM_LN``: 4 outputs, average pooling, Mish.

    GroupNorm on, no dropout.  ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        act=F.mish,
        reduce=True,
        output_size=4,
        use_AvgPool=True,
        use_ln=True,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )


def dt_conv_nolstm_mish_ln_1l_2d_out3_avgpool(width, **kwargs):
    """Pooled 2-D ``NetConvNOLSTM_LN``: 3 outputs, average pooling, Mish.

    GroupNorm on, no dropout.  ``kwargs`` must contain ``in_channels``.
    """
    in_channels = kwargs["in_channels"]
    return NetConvNOLSTM_LN(
        BasicBlock, [2],
        width=width,
        in_channels=in_channels,
        recall=True,
        act=F.mish,
        reduce=True,
        output_size=3,
        use_AvgPool=True,
        use_ln=True,
        norm_affine=False,
        _dropout_gal2=0,
        _dropout=0,
        dropout_method='pytorch',
    )
